_mul_reduce_max_fusion_kernel Class — pytorch Architecture

Architecture documentation for _mul_reduce_max_fusion_kernel, an inline function template (indexed here as a class) defined in FlashAttentionKernel.cpp in the pytorch codebase.

Source Code

aten/src/ATen/native/cpu/FlashAttentionKernel.cpp lines 127–158

template <typename scalar_t>
inline void _mul_reduce_max_fusion_kernel(
    const scalar_t* a,
    const scalar_t& scale,
    const int& size,
    scalar_t* out,
    scalar_t& max) {
  auto vec_size = vec::Vectorized<scalar_t>::size();
  auto vec_scale = vec::Vectorized<scalar_t>(scale);
  scalar_t tmp_max = -std::numeric_limits<scalar_t>::infinity();
  auto vec_tmp_max = vec::Vectorized<scalar_t>(tmp_max);
  for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) {
    auto tmp0 = vec::Vectorized<scalar_t>::loadu(a + i);
    auto tmp1 = tmp0 * vec_scale;
    vec_tmp_max = vec::maximum(vec_tmp_max, tmp1);
    _store(out + i, tmp1);
  }
  for (long i = vec_size * (size / vec_size); i < size; i++) {
    auto tmp0 = a[i];
    auto tmp1 = tmp0 * scale;
    tmp_max = std::max(tmp_max, tmp1);
    out[i] = tmp1;
  }
  auto reduced_tmp_max = vec::vec_reduce_all<scalar_t>(
      [](vec::Vectorized<scalar_t>& x, vec::Vectorized<scalar_t>& y) {
        return vec::maximum(x, y);
      },
      vec_tmp_max);
  // Guard against Q*K^T being NaN
  max = std::isnan(reduced_tmp_max) ? std::numeric_limits<scalar_t>::quiet_NaN()
                                    : std::max(tmp_max, reduced_tmp_max);
}
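
The listing above fuses three steps into one pass: multiply each input element by scale, store the result to out, and track the running maximum, with a NaN guard on the final value. To make that contract easier to see without the vec::Vectorized machinery, here is a minimal scalar sketch. The name mul_reduce_max_reference is hypothetical and not part of pytorch. The sketch also assumes that a NaN anywhere in the scaled values should poison the maximum; that is a simplification and may differ from the kernel for elements handled by its scalar tail loop, where std::max is applied directly.

```cpp
#include <algorithm>
#include <cmath>
#include <limits>

// Hypothetical scalar reference (not part of pytorch): writes a[i] * scale
// into out[i] and returns the maximum of the scaled values.
// Assumption: a NaN in any scaled value makes the result NaN.
template <typename T>
T mul_reduce_max_reference(const T* a, T scale, int size, T* out) {
  // Start from -infinity so any finite value replaces it.
  T running_max = -std::numeric_limits<T>::infinity();
  bool saw_nan = false;
  for (int i = 0; i < size; ++i) {
    T scaled = a[i] * scale;
    out[i] = scaled;  // store the scaled value, as the fused kernel does
    if (std::isnan(scaled)) {
      saw_nan = true;  // remember the NaN; std::max alone would drop it
    } else {
      running_max = std::max(running_max, scaled);
    }
  }
  return saw_nan ? std::numeric_limits<T>::quiet_NaN() : running_max;
}
```

Called with a = {1, -2, 3, 0.5} and scale = 0.5, this writes {0.5, -1, 1.5, 0.25} to out and returns 1.5. With size = 0 it returns negative infinity, the same initial value the kernel uses for tmp_max.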