Home / Class / dst_n Class — PyTorch Architecture

dst_n Class — pytorch Architecture

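Architecture documentation for the `VecMaskLoad` partial specialization in vec512_mask.h from the PyTorch codebase. (`dst_n` is not a class; it is the specialization's non-type template parameter giving the number of output vectors.)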

Entity Profile

Source Code

aten/src/ATen/cpu/vec/vec512/vec512_mask.h lines 12–45

// Masked load of `dst_n` AVX-512 vectors of float / int32_t, driven by a mask
// made of `mask_n == 2 * dst_n` vectors of `mask_t`.
//
// For output chunk `c`, mask vectors `2c` and `2c + 1` are combined, cast to
// two int64_t vectors and then narrowed to one int vector. Every 32-bit lane
// equal to all-ones selects the corresponding element of
// `ptr[c * Vectorized<T>::size() ...]`. Lanes that are not selected come from
// a zero vector.
template <typename T, int dst_n, typename mask_t, int mask_n>
struct VecMaskLoad<
    T,
    dst_n,
    mask_t,
    mask_n,
    typename std::enable_if_t<
        (mask_n == dst_n * 2 && dst_n >= 1) &&
            (std::is_same_v<T, float> || std::is_same_v<T, int32_t>),
        void>> {
  // ptr:      base address; chunk c is read starting at
  //           ptr + c * Vectorized<T>::size() (unaligned load intrinsics).
  // vec_mask: the 2 * dst_n mask vectors, consumed two per output chunk.
  // Returns the dst_n loaded vectors, with zeros in lanes the mask excludes.
  static inline VectorizedN<T, dst_n> apply(
      const T* ptr,
      const VecMask<mask_t, mask_n>& vec_mask) {
    // Source of values for lanes whose mask bit is clear.
    at::vec::Vectorized<T> fill(0);
    // Narrowed mask lanes are compared against this all-ones pattern.
    auto active_pattern = _mm512_set1_epi32(0xFFFFFFFF);
    VectorizedN<T, dst_n> out;
    for (int chunk = 0; chunk < dst_n; ++chunk) {
      // The two mask vectors that cover this output chunk.
      VectorizedN<mask_t, 2> mask_pair;
      mask_pair[0] = vec_mask[2 * chunk];
      mask_pair[1] = vec_mask[2 * chunk + 1];
      auto wide_mask =
          VecMask<mask_t, 2>(mask_pair).template cast<int64_t, 2>();
      auto narrow_mask = wide_mask.template cast<int, 1>()[0];
      // One bit per 32-bit lane: set where the narrowed mask is all-ones.
      auto lane_bits =
          _mm512_cmp_epi32_mask(narrow_mask, active_pattern, _MM_CMPINT_EQ);
      const T* src = ptr + chunk * Vectorized<T>::size();
      if constexpr (!std::is_same_v<T, float>) {
        out[chunk] =
            Vectorized<T>(_mm512_mask_loadu_epi32(fill, lane_bits, src));
      } else {
        out[chunk] = Vectorized<T>(_mm512_mask_loadu_ps(fill, lane_bits, src));
      }
    }
    return out;
  }
};

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free