q_split_size Template Parameter — pytorch Architecture
Architecture documentation for q_split_size, a non-type template parameter of the cpu_flash_attention kernel in FlashAttentionKernel.cpp from the pytorch codebase. q_split_size fixes, at compile time, the block size along the query sequence dimension; together with kv_split_size it defines the tiling over which the kernel streams the flash-attention computation.
Entity Profile
Source Code
aten/src/ATen/native/cpu/FlashAttentionKernel.cpp lines 307–755
template <typename scalar_t, typename mask_t, int64_t q_split_size, int64_t kv_split_size, bool with_pack=false>
void cpu_flash_attention(
const Tensor& output,
const Tensor& logsumexp,
const at::Tensor& q,
const at::Tensor& k,
const at::Tensor& v,
double dropout_p,
bool is_causal,
std::optional<Tensor> attn_mask,
std::optional<double> scale) {
// Query (Batch x Num_heads x Q_seq_len x Dim_per_head)
// -> (Batch x Q_seq_len x Num_heads x Dim_per_head)
// Key (Batch x KV_num_heads x KV_seq_len x Dim_per_head)
// -> (Batch x KV_seq_len x KV_num_heads x Dim_per_head)
// Value (Batch x KV_num_heads x KV_seq_len x Dim_per_head)
// -> (Batch x KV_seq_len x KV_num_heads x Dim_per_head)
at::Tensor query = q.transpose(1, 2);
at::Tensor key = k.transpose(1, 2);
at::Tensor value = v.transpose(1, 2);
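// Compute in opmath precision: for reduced floating types (BFloat16/Half),
// accum_t is float, so softmax statistics and accumulation run in fp32.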
constexpr bool is_reduced_type = is_reduced_floating_point_v<scalar_t>;
using accum_t = at::opmath_type<scalar_t>;
using Vec = vec::Vectorized<accum_t>;
accum_t scaling_factor =
sdp::calculate_scale(query, scale).expect_float();
// Sizes
TORCH_CHECK((query.size(3) == value.size(3)) && (key.size(3) == value.size(3)),
"scaled_dot_product_attention_flash_attention: Q/K/V should have the same head size");
int64_t batchSize = query.size(0);
int64_t qSize = query.size(1);
int64_t kvSize = value.size(1);
int64_t num_head = query.size(2);
int64_t kv_num_head = key.size(2);
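// Grouped-query attention: each KV head is shared by repeat_factor query heads.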
int64_t repeat_factor = num_head / kv_num_head;
int64_t headSize = query.size(3);
bool has_attn_mask = attn_mask.has_value() && attn_mask.value().numel();
if (has_attn_mask) {
reshape_attn_mask_to_4d(attn_mask.value(), batchSize, num_head, qSize, kvSize);
}
// Strides
int64_t qStrideB = query.stride(0);
int64_t qStrideM = query.stride(1);
int64_t qStrideH = query.stride(2);
int64_t kStrideB = key.stride(0);
int64_t kStrideN = key.stride(1);
int64_t kStrideH = key.stride(2);
int64_t vStrideB = value.stride(0);
int64_t vStrideN = value.stride(1);
int64_t vStrideH = value.stride(2);
int64_t oStrideB = output.stride(0);
int64_t oStrideM = output.stride(1);
int64_t oStrideH = output.stride(2);
int64_t lStrideB = logsumexp.stride(0);
int64_t lStrideM = logsumexp.stride(1);
int64_t lStrideH = logsumexp.stride(2);
int64_t mStrideB =
(has_attn_mask && attn_mask.value().size(0) > 1)
? attn_mask.value().stride(0)
: 0;
int64_t mStrideH =
(has_attn_mask && attn_mask.value().size(1) > 1)
? attn_mask.value().stride(1)
: 0;
int64_t mStrideM =
(has_attn_mask && attn_mask.value().size(2) > 1)
? attn_mask.value().stride(2)
: 0;
int64_t mStrideN =
(has_attn_mask && attn_mask.value().size(3) > 1)
? attn_mask.value().stride(3)
: 0;
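// Clamp the compile-time split sizes to the actual sequence lengths;
// qSlice/kvSlice are ceiling-division block counts, and kvTail is the
// length of the last, possibly partial, KV block.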
int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size;
int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size;
int64_t qSlice = (qSize + qSplitSize - 1) / qSplitSize;
int64_t kvSlice = (kvSize + kvSplitSize - 1) / kvSplitSize;
int64_t kvTail = (kvSize - 1) % kvSplitSize + 1;
int64_t num_thread = at::get_num_threads();
const auto dtype = query.scalar_type();
const auto accumulate_dtype = toOpMathType(dtype);
// Whether pack is needed
bool need_pack = false;
if (with_pack) {
// BFloat16 requires larger size as the fallback implementation
// mkl_gemm_bf16bf16f32 is faster than mkl_gemm_f16f16f32
int64_t thresh_size = (dtype == at::ScalarType::BFloat16) ? 64 : 16;
need_pack = kvSize >= thresh_size && qSize >= thresh_size;
// When the number of gemm is greater than the number of pack,
// the pack overhead can be overlapped.
if (need_pack) {
double pack_size = batchSize * kv_num_head * kvSize * headSize;
double qs_per_thread = (batchSize * num_head * qSlice + num_thread - 1) / num_thread;
double gemm_size_per_thread = qs_per_thread * qSplitSize *
(is_causal ? std::min(qSize, kvSize) : kvSize) * headSize;
need_pack = gemm_size_per_thread / pack_size >= (dtype == at::ScalarType::BFloat16 ? 4 : 1);
}
}
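// VNNI packing consumes elements in pairs, so odd head / KV-block sizes
// must be padded up to the next even value.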
// Pad is needed for packing when K is not even
bool headSize_even = headSize % 2 == 0;
int64_t eheadSize = need_pack && !headSize_even ? headSize + 1: headSize;
int64_t ekvSplitSize = need_pack && (kvSplitSize % 2 != 0) ? kvSplitSize + 1 : kvSplitSize;
int64_t ekvTail = need_pack && (kvTail % 2 != 0) ? kvTail + 1 : kvTail;
// Allocate per thread temp buf (accumulate type)
int64_t size_per_thread =
/* qk */ qSplitSize * kvSplitSize +
/* qk_max */ qSplitSize +
/* qk_sum */ qSplitSize +
/* dst */ qSplitSize * headSize;
at::Tensor buf = at::empty({num_thread, size_per_thread}, query.options().dtype(accumulate_dtype));
at::Tensor buf_reduced = at::empty(
{num_thread,
qSplitSize,
is_reduced_type ? ekvSplitSize : 0},
query.options());
// Data ptrs
const scalar_t* q_data = query.const_data_ptr<scalar_t>();
const scalar_t* k_data = key.const_data_ptr<scalar_t>();
const scalar_t* v_data = value.const_data_ptr<scalar_t>();
mask_t* mask_data = has_attn_mask
? attn_mask.value().data_ptr<mask_t>()
: nullptr;
scalar_t* out_data = output.data_ptr<scalar_t>();
accum_t* lse_data = logsumexp.data_ptr<accum_t>();
accum_t* buf_data = buf.data_ptr<accum_t>();
scalar_t* buf_reduced_data = is_reduced_type ? buf_reduced.data_ptr<scalar_t>() : nullptr;
// Buffer to store padding query and packing key/value
scalar_t* key_reorder_ptr = nullptr;
scalar_t* value_reorder_ptr = nullptr;
scalar_t* query_padding_ptr = nullptr;
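// Packed KV buffer length: full blocks rounded up to even ekvSplitSize,
// plus the padded tail block.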
int64_t kv_padding_size = (kvSize - 1) / kvSplitSize * ekvSplitSize + ekvTail;
at::Tensor key_t_reorder;
at::Tensor value_t_reorder;
at::Tensor qeury_t_padding;
if (need_pack) {
key_t_reorder = at::empty(
{batchSize, kv_num_head, eheadSize, kvSize},
c10::CppTypeToScalarType<scalar_t>::value);
value_t_reorder = at::empty(
{batchSize, kv_num_head, kv_padding_size, headSize},
c10::CppTypeToScalarType<scalar_t>::value);
key_reorder_ptr = key_t_reorder.data_ptr<scalar_t>();
value_reorder_ptr = value_t_reorder.data_ptr<scalar_t>();
}
if (!headSize_even && need_pack) {
qeury_t_padding = at::empty(
{num_thread, qSplitSize, eheadSize},
c10::CppTypeToScalarType<scalar_t>::value);
query_padding_ptr = qeury_t_padding.data_ptr<scalar_t>();
}
// Reorder K, V
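// Packed path: transpose each K block to [headSize, kvBlockSize], then
// repack K and V into VNNI layout so cpublas::brgemm can consume them.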
if (need_pack) {
at::Tensor tranpose_t_reorder = at::empty(
{num_thread, kvSplitSize, headSize},
c10::CppTypeToScalarType<scalar_t>::value);
scalar_t* transpose_buffer_ptr = tranpose_t_reorder.data_ptr<scalar_t>();
at::parallel_for(0, batchSize * kv_num_head * kvSlice, 1, [&](int64_t begin, int64_t end) {
int ompIdx = at::get_thread_num();
int64_t i = 0, kv_j = 0, l = 0, n = 0;
scalar_t* transpose_ptr = transpose_buffer_ptr + ompIdx * kvSplitSize * headSize;
at::native::data_index_init(begin, i, batchSize, kv_j, kv_num_head, l, kvSlice);
for ([[maybe_unused]] auto z : c10::irange(begin, end)) {
n = l * kvSplitSize;
int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
// transpose [kvBlockSize, headSize] -> [headSize, kvBlockSize]
utils::transpose<uint16_t>(
kvBlockSize,
headSize,
/* src_ptr */
reinterpret_cast<const uint16_t*>(k_data + i * kStrideB + kv_j * kStrideH + n * kStrideN),
/* ld_src */ kStrideN,
/* dst */ reinterpret_cast<uint16_t*>(transpose_ptr),
/* ld_dst */ kvBlockSize);
// Pack [headSize, kvBlockSize]
at::vec::pack_vnni2(
/* src */ reinterpret_cast<const uint16_t*>(transpose_ptr),
/* dst */ reinterpret_cast<uint16_t*>(key_reorder_ptr + i * kv_num_head * eheadSize * kvSize +
kv_j * eheadSize * kvSize + n * eheadSize),
/* ld_src */ kvBlockSize,
/* K */ headSize,
/* N */ kvBlockSize);
// Pack [kvBlockSize, headSize]
at::vec::pack_vnni2(
/* src */ reinterpret_cast<const uint16_t*>(v_data + i * vStrideB + kv_j * vStrideH + n * vStrideN),
/* dst */ reinterpret_cast<uint16_t*>(value_reorder_ptr +
i * kv_num_head * kv_padding_size * headSize +
kv_j * kv_padding_size * headSize + n * headSize),
/* ld_src */ vStrideN,
/* K */ kvBlockSize,
/* N */ headSize);
// Move to the next query
at::native::data_index_step(i, batchSize, kv_j, kv_num_head, l, kvSlice);
}
});
}
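// Main attention loop: one task per (batch, head, query block); each thread
// addresses its private slice of the scratch buffers via ompIdx.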
at::parallel_for(0, batchSize * num_head * qSlice, 1, [&](int64_t begin, int64_t end) {
int64_t i = 0, j = 0, k = 0;
data_index_init(begin, i, batchSize, j, num_head, k, qSlice);
int ompIdx = at::get_thread_num();
accum_t* buf_ptr = buf_data + ompIdx * size_per_thread;
accum_t* qk_data = buf_ptr;
accum_t* qk_max_data = qk_data + qSplitSize * kvSplitSize;
accum_t* qk_sum_data = qk_max_data + qSplitSize;
accum_t* dst_data = qk_sum_data + qSplitSize;
scalar_t* qk_reduced_data = is_reduced_type ? buf_reduced_data + ompIdx * qSplitSize * ekvSplitSize : nullptr;
scalar_t* query_t_padding_ptr = (!headSize_even && need_pack)
? query_padding_ptr + ompIdx * qSplitSize * eheadSize
: nullptr;
for ([[maybe_unused]] auto z : c10::irange(begin, end)) {
int64_t m = k * qSplitSize;
int64_t qBlockSize = std::min(qSplitSize, qSize - m);
int64_t kv_j = j / repeat_factor;
// Initialize max and sum
fill_stub(qk_max_data,
-std::numeric_limits<accum_t>::infinity(), qBlockSize);
fill_stub(qk_sum_data,
static_cast<accum_t>(0), qBlockSize);
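// With causal masking this query block attends to at most the first
// m + qBlockSize keys.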
int64_t num_keys = is_causal ? std::min(m + qBlockSize, kvSize) : kvSize;
if (!headSize_even && need_pack) {
// Pad query if headSize is not even
// [qBlockSize, headSize] -> [qBlockSize, eheadSize]
copy_value_with_pad<scalar_t>(
q_data + i * qStrideB + j * qStrideH + m * qStrideM,
query_t_padding_ptr,
qBlockSize,
headSize,
qBlockSize,
eheadSize,
qStrideM
);
}
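// Stream over KV blocks: each iteration computes one q @ k.T tile, applies
// scaling/masking, and folds it into the running softmax and output.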
for (int64_t n = 0; n < num_keys; n += kvSplitSize) {
int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
int64_t ekvBlockSize = (need_pack && kvBlockSize % 2 != 0) ? kvBlockSize + 1 : kvBlockSize;
// Calculate scale * q @ k.T
if (need_pack) {
if constexpr (is_reduced_floating_point_v<scalar_t>) {
cpublas::brgemm(
qBlockSize,
kvBlockSize,
eheadSize,
headSize_even ? qStrideM : eheadSize,
kvBlockSize,
kvBlockSize,
false,
!headSize_even
? query_t_padding_ptr
: q_data + i * qStrideB + j * qStrideH + m * qStrideM,
key_reorder_ptr + i * kv_num_head * eheadSize * kvSize +
kv_j * eheadSize * kvSize + n * eheadSize,
qk_data);
}
} else {
cpublas::gemm(
TransposeType::Transpose,
TransposeType::NoTranspose,
kvBlockSize,
qBlockSize,
headSize,
static_cast<accum_t>(1),
k_data + i * kStrideB + kv_j * kStrideH +
n * kStrideN,
kStrideN,
q_data + i * qStrideB + j * qStrideH +
m * qStrideM,
qStrideM,
static_cast<accum_t>(0),
qk_data,
kvBlockSize);
}
// Apply causal mask, fill unused with -inf
if (is_causal && num_keys - n <= kvSplitSize) {
for (const auto row : c10::irange(qBlockSize)) {
int64_t last_col = m + row - n;
accum_t* row_ptr = qk_data + row * kvBlockSize;
fill_stub(row_ptr + last_col + 1,
-std::numeric_limits<accum_t>::infinity(),
kvBlockSize - last_col - 1);
}
}
// Update attention weights with attention mask
// And apply scaling factor
// qk <- qk * scaling + attn_mask
if (has_attn_mask) {
for (int64_t row = 0; row < qBlockSize; ++row) {
#if __GNUC__ == 11 && defined(__ARM_FEATURE_SVE)
_scale_attn_mask_fusion_kernel(
qk_data + row * kvBlockSize,
mask_data + i * mStrideB + j * mStrideH +
(m + row) * mStrideM + (mStrideN == 0 ? 0 : n),
kvBlockSize,
qk_data + row * kvBlockSize,
scaling_factor,
mStrideN == 0);
#else
if (mStrideN == 0) {
_scale_attn_mask_fusion_kernel</*is_stride_0*/ true>(
qk_data + row * kvBlockSize,
mask_data + i * mStrideB + j * mStrideH +
(m + row) * mStrideM,
kvBlockSize,
qk_data + row * kvBlockSize,
scaling_factor);
} else {
_scale_attn_mask_fusion_kernel</*is_stride_0*/ false>(
qk_data + row * kvBlockSize,
mask_data + i * mStrideB + j * mStrideH +
(m + row) * mStrideM + n,
kvBlockSize,
qk_data + row * kvBlockSize,
scaling_factor);
}
#endif
}
}
// Update coefficients with Softmax
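// Online softmax: keep a running per-row max and exp-sum across KV blocks;
// whenever the max grows, rows of dst already accumulated are rescaled
// by exp(old_max - new_max).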
accum_t tmp_max = 0, tmp_sum = 0, exp_tmp = 0;
for (int64_t row = 0; row < qBlockSize; ++row) {
if (has_attn_mask) {
// max per row
tmp_max = at::vec::reduce_all<accum_t>(
[](Vec& x, Vec& y) { return at::vec::maximum(x, y); },
qk_data + row * kvBlockSize,
kvBlockSize);
} else {
// apply scaling factor and max per row in fusion
_mul_reduce_max_fusion_kernel(
qk_data + row * kvBlockSize,
scaling_factor,
kvBlockSize,
qk_data + row * kvBlockSize,
tmp_max);
}
tmp_max = qk_max_data[row] > tmp_max ? qk_max_data[row] : tmp_max;
if (tmp_max == -std::numeric_limits<accum_t>::infinity()) {
// to avoid `nan = exp2f(-inf - (-inf))`
fill_stub(conditional_data_ptr(qk_data, qk_reduced_data) + row * ekvBlockSize,
static_cast<scalar_t>(0), kvBlockSize);
} else {
tmp_sum = tmp_max;
// qk <- exp(qk - max) and sum per row
_exp_reduce_sum_fusion_kernel(
qk_data + row * kvBlockSize, kvBlockSize,
conditional_data_ptr(qk_data, qk_reduced_data) + row * ekvBlockSize,
tmp_sum);
// exp_tmp <- exp(max[row] - max)
exp_tmp = std::exp(qk_max_data[row] - tmp_max);
// sum[row] <- sum + exp_tmp * sum[row]
qk_sum_data[row] = tmp_sum + exp_tmp * qk_sum_data[row];
// max[row] <- max
qk_max_data[row] = tmp_max;
// dst <- dst * exp_tmp
if (n > 0) {
vec::map<accum_t>(
[exp_tmp](Vec x) { return x * Vec(exp_tmp); },
dst_data + row * headSize,
dst_data + row * headSize,
headSize);
}
}
if (need_pack && kvBlockSize % 2 != 0) {
// Pad: [qSplitSize, kvBlockSize] -> [qSplitSize, kvBlockSize + 1]
*(qk_reduced_data + row * (1 + kvBlockSize) + kvBlockSize) = scalar_t(0);
}
}
// Calculate Softmax(q @ k.T) @ v
if (need_pack) {
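// psize: start offset of this KV block within the even-padded packed value buffer.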
int64_t psize = n / kvSplitSize * ekvSplitSize;
if constexpr (is_reduced_floating_point_v<scalar_t>) {
cpublas::brgemm(
qBlockSize,
headSize,
ekvBlockSize,
ekvBlockSize,
headSize,
headSize,
n > 0,
qk_reduced_data,
value_reorder_ptr +
i * kv_num_head * kv_padding_size * headSize +
kv_j * kv_padding_size * headSize + psize * headSize,
dst_data);
}
} else {
cpublas::gemm(
TransposeType::NoTranspose,
TransposeType::NoTranspose,
headSize,
qBlockSize,
kvBlockSize,
static_cast<accum_t>(1),
v_data + i * vStrideB + kv_j * vStrideH +
n * vStrideN,
vStrideN,
conditional_data_ptr(qk_data, qk_reduced_data),
kvBlockSize,
n == 0 ? static_cast<accum_t>(0) : static_cast<accum_t>(1),
dst_data,
headSize);
}
}
// dst <- dst / sum[row]
// reorder MHA output with strides
for (int64_t row = 0; row < qBlockSize; ++row) {
// Row sums for full masked out rows are 0, we set them to 1
// in order to avoid NaNs in the output and instead set fully
// masked out rows to 0
qk_max_data[row] = qk_max_data[row] == -std::numeric_limits<accum_t>::infinity() ? 0 : qk_max_data[row];
qk_sum_data[row] = qk_sum_data[row] == 0 ? 1 : qk_sum_data[row];
accum_t sum_reciprocal = 1 / qk_sum_data[row];
vec::map<scalar_t>(
[sum_reciprocal](Vec x) { return x * Vec(sum_reciprocal); },
out_data + i * oStrideB + j * oStrideH + m * oStrideM + row * oStrideM,
dst_data + row * headSize,
headSize);
}
// Store logsumexp for backward
accum_t* lse_ptr = lse_data + i * lStrideB + j * lStrideH + m * lStrideM;
for (const auto row : c10::irange(qBlockSize)) {
lse_ptr[row * lStrideM] = qk_max_data[row]
+ std::log(qk_sum_data[row]);
}
// Move to the next query
data_index_step(i, batchSize, j, num_head, k, qSlice);
}
if (need_pack) {
cpublas::brgemm_release();
}
});
}
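How q_split_size is chosen. The kernel never picks its own block size: q_split_size and kv_split_size are non-type template arguments, fixed at compile time, and a dispatcher in the same file instantiates the template from the runtime query length, using a larger query block for longer sequences. The sketch below illustrates that pattern; the thresholds (768/192) and block sizes (256/64/32, with kv_split_size = 512) are illustrative assumptions, not a quote of the upstream dispatcher.

// Hypothetical dispatch sketch: thresholds and block sizes are assumptions.
template <typename scalar_t, typename mask_t>
void dispatch_cpu_flash_attention(
    const at::Tensor& output, const at::Tensor& logsumexp,
    const at::Tensor& q, const at::Tensor& k, const at::Tensor& v,
    double dropout_p, bool is_causal,
    std::optional<at::Tensor> attn_mask, std::optional<double> scale) {
  // q is (Batch x Num_heads x Q_seq_len x Dim_per_head), so size(2) is the query length.
  int64_t q_seq_len = q.size(2);
  if (q_seq_len >= 768) {
    cpu_flash_attention<scalar_t, mask_t, 256, 512>(
        output, logsumexp, q, k, v, dropout_p, is_causal, attn_mask, scale);
  } else if (q_seq_len >= 192) {
    cpu_flash_attention<scalar_t, mask_t, 64, 512>(
        output, logsumexp, q, k, v, dropout_p, is_causal, attn_mask, scale);
  } else {
    cpu_flash_attention<scalar_t, mask_t, 32, 512>(
        output, logsumexp, q, k, v, dropout_p, is_causal, attn_mask, scale);
  }
}

Why the rescaling works. The dst <- dst * exp_tmp step rests on the online-softmax invariant: if m_old is the running row max and a new KV block raises it to m_new, every previously accumulated term exp(x - m_old) must become exp(x - m_new), i.e. be multiplied by exp_tmp = exp(m_old - m_new). The self-contained toy program below checks that invariant for one row processed in two blocks (hypothetical scores, all "values" equal to 1; plain C++, not PyTorch API):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Toy online softmax over one row of scores, streamed in two blocks.
// Invariant after each block: acc == sum of exp(x - running_max) over all
// scores seen so far, so acc / running_sum yields the normalized result.
int main() {
  std::vector<std::vector<float>> blocks = {{1.0f, 3.0f}, {5.0f, 2.0f}};
  float running_max = -INFINITY;  // plays the role of qk_max_data[row]
  float running_sum = 0.0f;       // plays the role of qk_sum_data[row]
  float acc = 0.0f;               // one element of dst, with every v = 1
  for (const auto& block : blocks) {
    float block_max = running_max;
    for (float x : block) block_max = std::max(block_max, x);
    float exp_tmp = std::exp(running_max - block_max);  // rescale factor
    float block_sum = 0.0f;
    for (float x : block) block_sum += std::exp(x - block_max);
    running_sum = block_sum + exp_tmp * running_sum;  // sum[row] update
    acc = acc * exp_tmp + block_sum;                  // dst rescale + add
    running_max = block_max;
  }
  std::printf("normalized = %f\n", acc / running_sum);  // prints 1.000000
  return 0;
}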