_exp_reduce_sum_fusion_kernel Function Template — PyTorch Architecture
Architecture documentation for the _exp_reduce_sum_fusion_kernel function template in FlashAttentionKernel.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cpu/FlashAttentionKernel.cpp lines 86–123
template <typename T1, typename T2>
inline void _exp_reduce_sum_fusion_kernel(
T1* a,
const int& size,
T2* out,
T1& val) {
auto vec_size = vec::Vectorized<T1>::size();
auto vec_max = vec::Vectorized<T1>(val);
T1 tmp_sum = 0;
auto vec_tmp_sum = vec::Vectorized<T1>(tmp_sum);
for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) {
auto tmp0 = vec::Vectorized<T1>::loadu(a + i);
auto tmp1 = tmp0 - vec_max;
Vectorized<T1> tmp2;
if constexpr (std::is_same_v<T1, float> &&
(std::is_same_v<T2, at::BFloat16> || std::is_same_v<T2, at::Half>))
{
tmp2 = tmp1.fexp_u20();
} else {
tmp2 = tmp1.exp_u20();
}
vec_tmp_sum += tmp2;
_store(out + i, tmp2);
}
tmp_sum = vec::vec_reduce_all<T1>(
[](vec::Vectorized<T1>& x, vec::Vectorized<T1>& y) {
return x + y;
},
vec_tmp_sum);
for (long i = vec_size * (size / vec_size); i < size; i++) {
auto tmp0 = a[i];
auto tmp1 = tmp0 - val;
auto tmp2 = exp(tmp1);
tmp_sum += tmp2;
out[i] = tmp2;
}
val = tmp_sum;
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free