q_split_size Template Parameter — pytorch Architecture

Architecture documentation for the q_split_size template parameter of cpu_flash_attention in FlashAttentionKernel.cpp from the pytorch codebase. q_split_size and kv_split_size are compile-time, non-type template parameters that set the block sizes used to tile the query and key/value sequence dimensions in the blocked flash-attention computation.

Entity Profile

Source Code

aten/src/ATen/native/cpu/FlashAttentionKernel.cpp lines 307–755

template <typename scalar_t, typename mask_t, int64_t q_split_size, int64_t kv_split_size, bool with_pack=false>
void cpu_flash_attention(
    const Tensor& output,
    const Tensor& logsumexp,
    const at::Tensor& q,
    const at::Tensor& k,
    const at::Tensor& v,
    double dropout_p,
    bool is_causal,
    std::optional<Tensor> attn_mask,
    std::optional<double> scale) {
  // Query (Batch x Num_heads    x Q_seq_len    x Dim_per_head)
  //    -> (Batch x Q_seq_len    x Num_heads    x Dim_per_head)
  // Key   (Batch x KV_num_heads x KV_seq_len   x Dim_per_head)
  //    -> (Batch x KV_seq_len   x KV_num_heads x Dim_per_head)
  // Value (Batch x KV_num_heads x KV_seq_len   x Dim_per_head)
  //    -> (Batch x KV_seq_len   x KV_num_heads x Dim_per_head)
  at::Tensor query = q.transpose(1, 2);
  at::Tensor key = k.transpose(1, 2);
  at::Tensor value = v.transpose(1, 2);

  constexpr bool is_reduced_type = is_reduced_floating_point_v<scalar_t>;
  using accum_t = at::opmath_type<scalar_t>;
  using Vec = vec::Vectorized<accum_t>;
  accum_t scaling_factor =
      sdp::calculate_scale(query, scale).expect_float();

  // Sizes
  TORCH_CHECK((query.size(3) == value.size(3)) && (key.size(3) == value.size(3)),
        "scaled_dot_product_attention_flash_attention: Q/K/V should have the same head size");
  int64_t batchSize = query.size(0);
  int64_t qSize = query.size(1);
  int64_t kvSize = value.size(1);
  int64_t num_head = query.size(2);
  int64_t kv_num_head = key.size(2);
  int64_t repeat_factor = num_head / kv_num_head;
  int64_t headSize = query.size(3);

  bool has_attn_mask = attn_mask.has_value() && attn_mask.value().numel();
  if (has_attn_mask) {
    reshape_attn_mask_to_4d(attn_mask.value(), batchSize, num_head, qSize, kvSize);
  }

  // Strides
  int64_t qStrideB = query.stride(0);
  int64_t qStrideM = query.stride(1);
  int64_t qStrideH = query.stride(2);
  int64_t kStrideB = key.stride(0);
  int64_t kStrideN = key.stride(1);
  int64_t kStrideH = key.stride(2);
  int64_t vStrideB = value.stride(0);
  int64_t vStrideN = value.stride(1);
  int64_t vStrideH = value.stride(2);
  int64_t oStrideB = output.stride(0);
  int64_t oStrideM = output.stride(1);
  int64_t oStrideH = output.stride(2);
  int64_t lStrideB = logsumexp.stride(0);
  int64_t lStrideM = logsumexp.stride(1);
  int64_t lStrideH = logsumexp.stride(2);
  int64_t mStrideB =
      (has_attn_mask && attn_mask.value().size(0) > 1)
      ? attn_mask.value().stride(0)
      : 0;
  int64_t mStrideH =
      (has_attn_mask && attn_mask.value().size(1) > 1)
      ? attn_mask.value().stride(1)
      : 0;
  int64_t mStrideM =
      (has_attn_mask && attn_mask.value().size(2) > 1)
      ? attn_mask.value().stride(2)
      : 0;
  int64_t mStrideN =
      (has_attn_mask && attn_mask.value().size(3) > 1)
      ? attn_mask.value().stride(3)
      : 0;

  int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size;
  int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size;
  int64_t qSlice = (qSize + qSplitSize - 1) / qSplitSize;
  int64_t kvSlice = (kvSize + kvSplitSize - 1) / kvSplitSize;
  int64_t kvTail = (kvSize - 1) % kvSplitSize + 1;
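  // e.g. qSize = 1000, qSplitSize = 256 -> qSlice = 4 query blocks;
  //      kvSize = 1000, kvSplitSize = 512 -> kvSlice = 2 kv blocks, kvTail = 488 keys in the last block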
  int64_t num_thread = at::get_num_threads();

  const auto dtype = query.scalar_type();
  const auto accumulate_dtype = toOpMathType(dtype);

  // Whether pack is needed
  bool need_pack = false;
  if (with_pack) {
    // BFloat16 requires larger size as the fallback implementation
    // mkl_gemm_bf16bf16f32 is faster than mkl_gemm_f16f16f32
    int64_t thresh_size = (dtype == at::ScalarType::BFloat16) ? 64 : 16;
    need_pack = kvSize >= thresh_size && qSize >= thresh_size;
    // When the number of gemm is greater than the number of pack,
    // the pack overhead can be overlapped.
    if (need_pack) {
      double pack_size = batchSize * kv_num_head * kvSize * headSize;
      double qs_per_thread = (batchSize * num_head * qSlice + num_thread - 1) / num_thread;
      double gemm_size_per_thread = qs_per_thread * qSplitSize *
          (is_causal ? std::min(qSize, kvSize) : kvSize) * headSize;
      need_pack = gemm_size_per_thread / pack_size >= (dtype == at::ScalarType::BFloat16 ? 4 : 1);
    }
  }

  // Pad is needed for packing when K is not even
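  // (the "e"-prefixed sizes below are rounded up to the next even number because
  // the VNNI packing done by pack_vnni2 interleaves pairs of elements along the
  // reduction dimension)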
  bool headSize_even = headSize % 2 == 0;
  int64_t eheadSize = need_pack && !headSize_even ? headSize + 1: headSize;
  int64_t ekvSplitSize = need_pack && (kvSplitSize % 2 != 0) ? kvSplitSize + 1 : kvSplitSize;
  int64_t ekvTail = need_pack && (kvTail % 2 != 0) ? kvTail + 1 : kvTail;

  // Allocate per thread temp buf (accumulate type)
  int64_t size_per_thread =
      /* qk     */ qSplitSize * kvSplitSize +
      /* qk_max */ qSplitSize +
      /* qk_sum */ qSplitSize +
      /* dst    */ qSplitSize * headSize;

  at::Tensor buf = at::empty({num_thread, size_per_thread}, query.options().dtype(accumulate_dtype));
  at::Tensor buf_reduced = at::empty(
    {num_thread,
     qSplitSize,
     is_reduced_type ? ekvSplitSize : 0},
     query.options());

  // Data ptrs
  const scalar_t* q_data = query.const_data_ptr<scalar_t>();
  const scalar_t* k_data = key.const_data_ptr<scalar_t>();
  const scalar_t* v_data = value.const_data_ptr<scalar_t>();
  mask_t* mask_data = has_attn_mask
      ? attn_mask.value().data_ptr<mask_t>()
      : nullptr;
  scalar_t* out_data = output.data_ptr<scalar_t>();
  accum_t* lse_data = logsumexp.data_ptr<accum_t>();
  accum_t* buf_data = buf.data_ptr<accum_t>();
  scalar_t* buf_reduced_data = is_reduced_type ? buf_reduced.data_ptr<scalar_t>() : nullptr;

  // Buffer to store padding query and packing key/value
  scalar_t* key_reorder_ptr = nullptr;
  scalar_t* value_reorder_ptr = nullptr;
  scalar_t* query_padding_ptr = nullptr;
  int64_t kv_padding_size = (kvSize - 1) / kvSplitSize * ekvSplitSize + ekvTail;
  at::Tensor key_t_reorder;
  at::Tensor value_t_reorder;
  at::Tensor qeury_t_padding;
  if (need_pack) {
    key_t_reorder = at::empty(
      {batchSize, kv_num_head, eheadSize, kvSize},
      c10::CppTypeToScalarType<scalar_t>::value);
    value_t_reorder = at::empty(
      {batchSize, kv_num_head, kv_padding_size, headSize},
      c10::CppTypeToScalarType<scalar_t>::value);
    key_reorder_ptr = key_t_reorder.data_ptr<scalar_t>();
    value_reorder_ptr = value_t_reorder.data_ptr<scalar_t>();
  }

  if (!headSize_even && need_pack) {
    qeury_t_padding = at::empty(
      {num_thread, qSplitSize, eheadSize},
      c10::CppTypeToScalarType<scalar_t>::value);
    query_padding_ptr = qeury_t_padding.data_ptr<scalar_t>();
  }

  // Reorder K, V
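  // Pack K and V once per (batch, kv_head, kv block) up front so that every
  // query block can reuse the packed operands in the brgemm calls below.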
  if (need_pack) {
    at::Tensor tranpose_t_reorder = at::empty(
      {num_thread, kvSplitSize, headSize},
      c10::CppTypeToScalarType<scalar_t>::value);
    scalar_t* transpose_buffer_ptr = tranpose_t_reorder.data_ptr<scalar_t>();
    at::parallel_for(0, batchSize * kv_num_head * kvSlice, 1, [&](int64_t begin, int64_t end) {
        int ompIdx = at::get_thread_num();
        int64_t i = 0, kv_j = 0, l = 0, n = 0;
        scalar_t* transpose_ptr = transpose_buffer_ptr + ompIdx * kvSplitSize * headSize;
        at::native::data_index_init(begin, i, batchSize, kv_j, kv_num_head, l, kvSlice);
        for ([[maybe_unused]] auto z : c10::irange(begin, end)) {
          n = l * kvSplitSize;
          int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);

          // transpose [kvBlockSize, headSize] -> [headSize, kvBlockSize]
          utils::transpose<uint16_t>(
              kvBlockSize,
              headSize,
              /* src_ptr */
              reinterpret_cast<const uint16_t*>(k_data + i * kStrideB + kv_j * kStrideH + n * kStrideN),
              /* ld_src */ kStrideN,
              /* dst */ reinterpret_cast<uint16_t*>(transpose_ptr),
              /* ld_dst */ kvBlockSize);

          // Pack [headSize, kvBlockSize]
          at::vec::pack_vnni2(
            /* src */ reinterpret_cast<const uint16_t*>(transpose_ptr),
            /* dst */ reinterpret_cast<uint16_t*>(key_reorder_ptr + i * kv_num_head * eheadSize * kvSize +
                    kv_j * eheadSize * kvSize + n * eheadSize),
            /* ld_src */ kvBlockSize,
            /* K */ headSize,
            /* N */ kvBlockSize);

          // Pack [kvBlockSize, headSize]
          at::vec::pack_vnni2(
            /* src */ reinterpret_cast<const uint16_t*>(v_data + i * vStrideB + kv_j * vStrideH + n * vStrideN),
            /* dst */ reinterpret_cast<uint16_t*>(value_reorder_ptr +
                    i * kv_num_head * kv_padding_size * headSize +
                    kv_j * kv_padding_size * headSize + n * headSize),
            /* ld_src */ vStrideN,
            /* K */ kvBlockSize,
            /* N */ headSize);

          // Move to the next kv block
          at::native::data_index_step(i, batchSize, kv_j, kv_num_head, l, kvSlice);
        }
      });
  }

  at::parallel_for(0, batchSize * num_head * qSlice, 1, [&](int64_t begin, int64_t end) {
    int64_t i = 0, j = 0, k = 0;
    data_index_init(begin, i, batchSize, j, num_head, k, qSlice);
    int ompIdx = at::get_thread_num();
    accum_t* buf_ptr = buf_data + ompIdx * size_per_thread;
    accum_t* qk_data = buf_ptr;
    accum_t* qk_max_data = qk_data + qSplitSize * kvSplitSize;
    accum_t* qk_sum_data = qk_max_data + qSplitSize;
    accum_t* dst_data = qk_sum_data + qSplitSize;
    scalar_t* qk_reduced_data = is_reduced_type ? buf_reduced_data + ompIdx * qSplitSize * ekvSplitSize : nullptr;
    scalar_t* query_t_padding_ptr = (!headSize_even && need_pack)
            ? query_padding_ptr + ompIdx * qSplitSize * eheadSize
            : nullptr;

    for ([[maybe_unused]] auto z : c10::irange(begin, end)) {
      int64_t m = k * qSplitSize;
      int64_t qBlockSize = std::min(qSplitSize, qSize - m);
      int64_t kv_j = j / repeat_factor;
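      // Grouped-query attention: repeat_factor (= num_head / kv_num_head)
      // query heads share a single kv head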
      // Initialize max and sum
      fill_stub(qk_max_data,
          -std::numeric_limits<accum_t>::infinity(), qBlockSize);
      fill_stub(qk_sum_data,
          static_cast<accum_t>(0), qBlockSize);
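      // With causal masking, query row m + row attends only to keys at
      // positions <= m + row, so the key loop can stop at m + qBlockSize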
      int64_t num_keys = is_causal ? std::min(m + qBlockSize, kvSize) : kvSize;
      if (!headSize_even && need_pack) {
        // Pad query if headSize is not even
        // [qBlockSize, headSize] -> [qBlockSize, eheadSize]
        copy_value_with_pad<scalar_t>(
          q_data + i * qStrideB + j * qStrideH + m * qStrideM,
          query_t_padding_ptr,
          qBlockSize,
          headSize,
          qBlockSize,
          eheadSize,
          qStrideM
        );
      }
      for (int64_t n = 0; n < num_keys; n += kvSplitSize) {
        int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
        int64_t ekvBlockSize = (need_pack && kvBlockSize % 2 != 0) ? kvBlockSize + 1 : kvBlockSize;
        // Calculate scale * q @ k.T
        if (need_pack) {
          if constexpr (is_reduced_floating_point_v<scalar_t>) {
            cpublas::brgemm(
                qBlockSize,
                kvBlockSize,
                eheadSize,
                headSize_even ? qStrideM : eheadSize,
                kvBlockSize,
                kvBlockSize,
                false,
                !headSize_even
                    ? query_t_padding_ptr
                    : q_data + i * qStrideB + j * qStrideH + m * qStrideM,
                key_reorder_ptr + i * kv_num_head * eheadSize * kvSize +
                    kv_j * eheadSize * kvSize + n * eheadSize,
                qk_data);
          }
        } else {
          cpublas::gemm(
            TransposeType::Transpose,
            TransposeType::NoTranspose,
            kvBlockSize,
            qBlockSize,
            headSize,
            static_cast<accum_t>(1),
            k_data + i * kStrideB + kv_j * kStrideH +
                n * kStrideN,
            kStrideN,
            q_data + i * qStrideB + j * qStrideH +
                m * qStrideM,
            qStrideM,
            static_cast<accum_t>(0),
            qk_data,
            kvBlockSize);
        }
        // Apply causal mask, fill unused with -inf
        if (is_causal && num_keys - n <= kvSplitSize) {
          for (const auto row : c10::irange(qBlockSize)) {
            int64_t last_col = m + row - n;
            accum_t* row_ptr = qk_data + row * kvBlockSize;
            fill_stub(row_ptr + last_col + 1,
                -std::numeric_limits<accum_t>::infinity(),
                kvBlockSize - last_col - 1);
          }
        }
        // Update attention weights with attention mask
        // And apply scaling factor
        // qk <- qk * scaling + attn_mask
        if (has_attn_mask) {
          for (int64_t row = 0; row < qBlockSize; ++row) {
#if __GNUC__ == 11 && defined(__ARM_FEATURE_SVE)
              _scale_attn_mask_fusion_kernel(
                qk_data + row * kvBlockSize,
                mask_data + i * mStrideB + j * mStrideH +
                    (m + row) * mStrideM + (mStrideN == 0 ? 0 : n),
                kvBlockSize,
                qk_data + row * kvBlockSize,
                scaling_factor,
                mStrideN == 0);
#else
              if (mStrideN == 0) {
                _scale_attn_mask_fusion_kernel</*is_stride_0*/ true>(
                  qk_data + row * kvBlockSize,
                  mask_data + i * mStrideB + j * mStrideH +
                      (m + row) * mStrideM,
                  kvBlockSize,
                  qk_data + row * kvBlockSize,
                  scaling_factor);
              } else {
                _scale_attn_mask_fusion_kernel</*is_stride_0*/ false>(
                  qk_data + row * kvBlockSize,
                  mask_data + i * mStrideB + j * mStrideH +
                      (m + row) * mStrideM + n,
                  kvBlockSize,
                  qk_data + row * kvBlockSize,
                  scaling_factor);
              }
#endif
          }
        }
        // Update coefficients with Softmax
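        // Online (streaming) softmax: maintain a running per-row max and sum,
        // and rescale the partial output in dst whenever the running max grows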
        accum_t tmp_max = 0, tmp_sum = 0, exp_tmp = 0;
        for (int64_t row = 0; row < qBlockSize; ++row) {
          if (has_attn_mask) {
            // max per row
            tmp_max = at::vec::reduce_all<accum_t>(
                [](Vec& x, Vec& y) { return at::vec::maximum(x, y); },
                qk_data + row * kvBlockSize,
                kvBlockSize);
          } else {
            // apply scaling factor and max per row in fusion
            _mul_reduce_max_fusion_kernel(
                qk_data + row * kvBlockSize,
                scaling_factor,
                kvBlockSize,
                qk_data + row * kvBlockSize,
                tmp_max);
          }
          tmp_max = qk_max_data[row] > tmp_max ? qk_max_data[row] : tmp_max;
          if (tmp_max == -std::numeric_limits<accum_t>::infinity()) {
            // to avoid `nan = exp2f(-inf - (-inf))`
            fill_stub(conditional_data_ptr(qk_data, qk_reduced_data) + row * ekvBlockSize,
              static_cast<scalar_t>(0), kvBlockSize);
          } else {
            tmp_sum = tmp_max;
            // qk <- exp(qk - max) and sum per row
            _exp_reduce_sum_fusion_kernel(
                qk_data + row * kvBlockSize, kvBlockSize,
                conditional_data_ptr(qk_data, qk_reduced_data) + row * ekvBlockSize,
                tmp_sum);
            // exp_tmp <- exp(max[row] - max)
            exp_tmp = std::exp(qk_max_data[row] - tmp_max);
            // sum[row] <- sum + exp_tmp * sum[row]
            qk_sum_data[row] = tmp_sum + exp_tmp * qk_sum_data[row];
            // max[row] <- max
            qk_max_data[row] = tmp_max;
            // dst <- dst * exp_tmp
            if (n > 0) {
              vec::map<accum_t>(
                [exp_tmp](Vec x) { return x * Vec(exp_tmp); },
                dst_data + row * headSize,
                dst_data + row * headSize,
                headSize);
            }
          }
          if (need_pack && kvBlockSize % 2 != 0) {
            // Pad: [qSplitSize, kvBlockSize] -> [qSplitSize, kvBlockSize + 1]
            *(qk_reduced_data + row * (1 + kvBlockSize) + kvBlockSize) = scalar_t(0);
          }
        }
        // Calculate Softmax(q @ k.T) @ v
        if (need_pack) {
          int64_t psize = n / kvSplitSize * ekvSplitSize;
          if constexpr (is_reduced_floating_point_v<scalar_t>) {
            cpublas::brgemm(
                  qBlockSize,
                  headSize,
                  ekvBlockSize,
                  ekvBlockSize,
                  headSize,
                  headSize,
                  n > 0,
                  qk_reduced_data,
                  value_reorder_ptr +
                      i * kv_num_head * kv_padding_size * headSize +
                      kv_j * kv_padding_size * headSize + psize * headSize,
                  dst_data);
          }
        } else {
          cpublas::gemm(
            TransposeType::NoTranspose,
            TransposeType::NoTranspose,
            headSize,
            qBlockSize,
            kvBlockSize,
            static_cast<accum_t>(1),
            v_data + i * vStrideB + kv_j * vStrideH +
                n * vStrideN,
            vStrideN,
            conditional_data_ptr(qk_data, qk_reduced_data),
            kvBlockSize,
            n == 0 ? static_cast<accum_t>(0) : static_cast<accum_t>(1),
            dst_data,
            headSize);
        }
      }

      // dst <- dst / sum[row]
      // reorder MHA output with strides
      for (int64_t row = 0; row < qBlockSize; ++row) {
        // Row sums for full masked out rows are 0, we set them to 1
        // in order to avoid NaNs in the output and instead set fully
        // masked out rows to 0
        qk_max_data[row] = qk_max_data[row] == -std::numeric_limits<accum_t>::infinity() ? 0 : qk_max_data[row];
        qk_sum_data[row] = qk_sum_data[row] == 0 ? 1 : qk_sum_data[row];
        accum_t sum_reciprocal = 1 / qk_sum_data[row];
        vec::map<scalar_t>(
          [sum_reciprocal](Vec x) { return x * Vec(sum_reciprocal); },
          out_data + i * oStrideB + j * oStrideH + m * oStrideM + row * oStrideM,
          dst_data + row * headSize,
          headSize);
      }
      // Store logsumexp for backward
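      // logsumexp = rowmax + log(rowsum), i.e. the log of the softmax
      // normalizer; the backward pass uses it to recompute attention weights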
      accum_t* lse_ptr = lse_data + i * lStrideB + j * lStrideH + m * lStrideM;
      for (const auto row : c10::irange(qBlockSize)) {
        lse_ptr[row * lStrideM] = qk_max_data[row]
            + std::log(qk_sum_data[row]);
      }
      // Move to the next query
      data_index_step(i, batchSize, j, num_head, k, qSlice);
    }
    if (need_pack) {
      cpublas::brgemm_release();
    }
  });
}
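
For orientation, the sketch below re-implements just the blocking and online-softmax recurrence from the loops above in plain scalar C++. It is an illustrative toy, not the kernel: the sizes, the input data, and the single-head, no-mask, single-threaded setup are all made up, and there is no packing, vectorization, or reduced-precision handling. It only shows how q_split_size and kv_split_size tile the two sequence dimensions while each query row carries a running max, running sum, and a rescaled partial output.

// Illustrative sketch only: scalar flash-attention tiling with the same
// q/kv block structure as cpu_flash_attention. Sizes and data are made up.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <limits>
#include <vector>

int main() {
  const int64_t qSize = 7, kvSize = 10, headSize = 4;
  const int64_t qSplitSize = 4, kvSplitSize = 3;  // tiny stand-ins for q_split_size / kv_split_size
  const double scale = 1.0 / std::sqrt((double)headSize);

  // Toy Q, K, V filled with a deterministic pattern.
  std::vector<double> Q(qSize * headSize), K(kvSize * headSize), V(kvSize * headSize);
  for (size_t t = 0; t < Q.size(); ++t) Q[t] = 0.01 * (double)(t % 13);
  for (size_t t = 0; t < K.size(); ++t) K[t] = 0.02 * (double)(t % 7);
  for (size_t t = 0; t < V.size(); ++t) V[t] = 0.03 * (double)(t % 5);
  std::vector<double> O(qSize * headSize, 0.0);

  for (int64_t m = 0; m < qSize; m += qSplitSize) {           // query block
    int64_t qBlock = std::min(qSplitSize, qSize - m);
    std::vector<double> rowMax(qBlock, -std::numeric_limits<double>::infinity());
    std::vector<double> rowSum(qBlock, 0.0);
    std::vector<double> dst(qBlock * headSize, 0.0);           // unnormalized partial output

    for (int64_t n = 0; n < kvSize; n += kvSplitSize) {        // kv block
      int64_t kvBlock = std::min(kvSplitSize, kvSize - n);
      for (int64_t r = 0; r < qBlock; ++r) {
        // qk[c] = scale * <Q[m+r], K[n+c]> for this block
        std::vector<double> qk(kvBlock);
        double blockMax = -std::numeric_limits<double>::infinity();
        for (int64_t c = 0; c < kvBlock; ++c) {
          double dot = 0.0;
          for (int64_t h = 0; h < headSize; ++h)
            dot += Q[(m + r) * headSize + h] * K[(n + c) * headSize + h];
          qk[c] = scale * dot;
          blockMax = std::max(blockMax, qk[c]);
        }
        double newMax = std::max(rowMax[r], blockMax);
        double correction = std::exp(rowMax[r] - newMax);      // rescales old sum and old dst
        double blockSum = 0.0;
        for (int64_t c = 0; c < kvBlock; ++c) {
          qk[c] = std::exp(qk[c] - newMax);
          blockSum += qk[c];
        }
        rowSum[r] = blockSum + correction * rowSum[r];
        rowMax[r] = newMax;
        for (int64_t h = 0; h < headSize; ++h) {
          double acc = correction * dst[r * headSize + h];
          for (int64_t c = 0; c < kvBlock; ++c)
            acc += qk[c] * V[(n + c) * headSize + h];
          dst[r * headSize + h] = acc;
        }
      }
    }
    // Normalize once per query block, like the dst / rowSum step in the kernel.
    for (int64_t r = 0; r < qBlock; ++r)
      for (int64_t h = 0; h < headSize; ++h)
        O[(m + r) * headSize + h] = dst[r * headSize + h] / rowSum[r];
  }

  std::printf("O[0][0..3] = %.6f %.6f %.6f %.6f\n", O[0], O[1], O[2], O[3]);
  return 0;
}

The invariant mirrored here is the one the kernel maintains per query row: dst holds the sum of exp(qk - rowMax) * v over the keys processed so far, so when a later kv block raises the running max, both dst and rowSum are rescaled by exp(oldMax - newMax) before the new block's contribution is accumulated.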
