_mul_reduce_max_fusion_kernel Function — PyTorch Architecture
Architecture documentation for the _mul_reduce_max_fusion_kernel template function in FlashAttentionKernel.cpp from the PyTorch codebase. The kernel fuses an elementwise scaling of the attention scores with a running max reduction, producing both the scaled row and its maximum in a single pass for the CPU FlashAttention softmax.
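Restated in symbols (notation mine: s for the scale, n for the size, m for the max out-parameter; this simply transcribes the code below):

\[
\mathrm{out}_i = s \, a_i \quad (0 \le i < n), \qquad
m =
\begin{cases}
\mathrm{NaN} & \text{if the reduced vector maximum is NaN},\\
\max_{0 \le i < n} \mathrm{out}_i & \text{otherwise.}
\end{cases}
\]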
Entity Profile
Source Code
aten/src/ATen/native/cpu/FlashAttentionKernel.cpp lines 127–158
template <typename scalar_t>
inline void _mul_reduce_max_fusion_kernel(
    const scalar_t* a,
    const scalar_t& scale,
    const int& size,
    scalar_t* out,
    scalar_t& max) {
  auto vec_size = vec::Vectorized<scalar_t>::size();
  auto vec_scale = vec::Vectorized<scalar_t>(scale);
  scalar_t tmp_max = -std::numeric_limits<scalar_t>::infinity();
  auto vec_tmp_max = vec::Vectorized<scalar_t>(tmp_max);
  for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) {
    auto tmp0 = vec::Vectorized<scalar_t>::loadu(a + i);
    auto tmp1 = tmp0 * vec_scale;
    vec_tmp_max = vec::maximum(vec_tmp_max, tmp1);
    _store(out + i, tmp1);
  }
  for (long i = vec_size * (size / vec_size); i < size; i++) {
    auto tmp0 = a[i];
    auto tmp1 = tmp0 * scale;
    tmp_max = std::max(tmp_max, tmp1);
    out[i] = tmp1;
  }
  auto reduced_tmp_max = vec::vec_reduce_all<scalar_t>(
      [](vec::Vectorized<scalar_t>& x, vec::Vectorized<scalar_t>& y) {
        return vec::maximum(x, y);
      },
      vec_tmp_max);
  // Guard against Q*K^T being NaN
  max = std::isnan(reduced_tmp_max) ? std::numeric_limits<scalar_t>::quiet_NaN()
                                    : std::max(tmp_max, reduced_tmp_max);
}
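The kernel makes a single pass over the input: the first loop walks vec_size-wide chunks, loading with loadu, multiplying by the broadcast vec_scale, accumulating a lane-wise running maximum with vec::maximum (which, unlike std::max, propagates NaN), and writing the product through _store; the second loop handles the scalar tail of size % vec_size elements; vec::vec_reduce_all then folds the vector lanes into one scalar, and the final ternary propagates a NaN from the Q*K^T scores instead of letting std::max silently drop it.

For reference, here is a minimal, self-contained scalar sketch of the same fused operation. It is an illustration, not the ATen code: the name mul_reduce_max_reference is hypothetical, and where the kernel above detects NaN through the vector-lane reduction, this sketch flags NaN explicitly on every element.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

// Hypothetical scalar reference for _mul_reduce_max_fusion_kernel:
// scale each element, store the product, and keep a running maximum.
template <typename scalar_t>
void mul_reduce_max_reference(
    const scalar_t* a,
    scalar_t scale,
    int size,
    scalar_t* out,
    scalar_t& max) {
  scalar_t tmp_max = -std::numeric_limits<scalar_t>::infinity();
  bool saw_nan = false;
  for (int i = 0; i < size; i++) {
    scalar_t scaled = a[i] * scale;
    out[i] = scaled;                      // fused store of the scaled value
    saw_nan = saw_nan || std::isnan(scaled);
    tmp_max = std::max(tmp_max, scaled);  // std::max drops NaN, hence the flag
  }
  // Same intent as the kernel's guard: a NaN score must poison the row
  // maximum rather than be silently dropped.
  max = saw_nan ? std::numeric_limits<scalar_t>::quiet_NaN() : tmp_max;
}

int main() {
  std::vector<float> a = {1.0f, -2.0f, 3.0f, 0.5f};
  std::vector<float> out(a.size());
  float row_max = 0.0f;
  mul_reduce_max_reference(a.data(), 0.125f, static_cast<int>(a.size()),
                           out.data(), row_max);
  std::printf("row max = %f\n", row_max);  // prints 0.375 (= 3.0 * 0.125)
  return 0;
}

Fusing the multiply with the reduction yields both the scaled scores and the row maximum that the numerically stable softmax needs in one traversal of the buffer, rather than re-reading it for a separate max pass.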