Home / Class/ is_same_v Class — pytorch Architecture

is_same_v Class — pytorch Architecture

Architecture documentation for the `Vectorized8` class template in vec512_int.h from the PyTorch codebase (indexed here under `is_same_v`, the type trait used in the class's `static_assert`).

Entity Profile

Source Code

aten/src/ATen/cpu/vec/vec512/vec512_int.h lines 727–1138

// Shared AVX-512 implementation for 64-lane vectors of 8-bit integers.
// T must be int8_t or uint8_t (enforced by the static_assert below). Lane
// data lives in the `values` __m512i member, which is not declared in this
// class (presumably inherited from Vectorizedi).
template <typename T>
class Vectorized8 : public Vectorizedi {
  static_assert(
      std::is_same_v<T, int8_t> || std::is_same_v<T, uint8_t>,
      "Only int8_t/uint8_t are supported");

 protected:
  static constexpr __m512i zero_vector{0, 0, 0, 0, 0, 0, 0, 0};
  static const Vectorized<T> ones;

 public:
  using value_type = T;
  // Number of 8-bit lanes in one 512-bit register.
  static constexpr int size() {
    return 64;
  }
  using Vectorizedi::Vectorizedi;
  Vectorized8() {}
  // Broadcasts `v` into every lane.
  Vectorized8(T v) {
    values = _mm512_set1_epi8(v);
  }
  // Builds a vector from 64 explicit lane values; val1 becomes lane 0.
  // _mm512_set_epi8 takes its arguments from the highest lane down, hence
  // the reversed argument order.
  Vectorized8(
      T val1, T val2, T val3, T val4, T val5, T val6, T val7, T val8,
      T val9, T val10, T val11, T val12, T val13, T val14, T val15, T val16,
      T val17, T val18, T val19, T val20, T val21, T val22, T val23, T val24,
      T val25, T val26, T val27, T val28, T val29, T val30, T val31, T val32,
      T val33, T val34, T val35, T val36, T val37, T val38, T val39, T val40,
      T val41, T val42, T val43, T val44, T val45, T val46, T val47, T val48,
      T val49, T val50, T val51, T val52, T val53, T val54, T val55, T val56,
      T val57, T val58, T val59, T val60, T val61, T val62, T val63, T val64) {
    values = _mm512_set_epi8(
        val64, val63, val62, val61, val60, val59, val58, val57,
        val56, val55, val54, val53, val52, val51, val50, val49,
        val48, val47, val46, val45, val44, val43, val42, val41,
        val40, val39, val38, val37, val36, val35, val34, val33,
        val32, val31, val30, val29, val28, val27, val26, val25,
        val24, val23, val22, val21, val20, val19, val18, val17,
        val16, val15, val14, val13, val12, val11, val10, val9,
        val8, val7, val6, val5, val4, val3, val2, val1);
  }
  // Compile-time blend: lane i is taken from `b` when bit i of `mask` is set,
  // otherwise from `a`.
  template <int64_t mask>
  static Vectorized<T> blend(Vectorized<T> a, Vectorized<T> b) {
    return _mm512_mask_blend_epi8(mask, a.values, b.values);
  }
  // Returns {base, base + step, ..., base + 63 * step}; each lane value is
  // converted to T when passed to the 64-argument constructor.
  template <typename step_t>
  static Vectorized<T> arange(
      T base = 0,
      step_t step = static_cast<step_t>(1)) {
    return Vectorized<T>(
        base, base + step, base + 2 * step, base + 3 * step,
        base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step,
        base + 8 * step, base + 9 * step, base + 10 * step, base + 11 * step,
        base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step,
        base + 16 * step, base + 17 * step, base + 18 * step, base + 19 * step,
        base + 20 * step, base + 21 * step, base + 22 * step, base + 23 * step,
        base + 24 * step, base + 25 * step, base + 26 * step, base + 27 * step,
        base + 28 * step, base + 29 * step, base + 30 * step, base + 31 * step,
        base + 32 * step, base + 33 * step, base + 34 * step, base + 35 * step,
        base + 36 * step, base + 37 * step, base + 38 * step, base + 39 * step,
        base + 40 * step, base + 41 * step, base + 42 * step, base + 43 * step,
        base + 44 * step, base + 45 * step, base + 46 * step, base + 47 * step,
        base + 48 * step, base + 49 * step, base + 50 * step, base + 51 * step,
        base + 52 * step, base + 53 * step, base + 54 * step, base + 55 * step,
        base + 56 * step, base + 57 * step, base + 58 * step, base + 59 * step,
        base + 60 * step, base + 61 * step, base + 62 * step, base + 63 * step);
  }
  // Returns a vector whose first `count` lanes come from `b` and whose
  // remaining lanes come from `a`. count == 0 yields `a`; any count outside
  // [0, size()) -- including the default, size() -- yields `b`.
  static Vectorized<T> set(Vectorized<T> a, Vectorized<T> b, T count = size()) {
    // Widen first so the range check below is well-formed for both the
    // signed and unsigned element type.
    const int n = static_cast<int>(count);
    if (n < 0 || n >= size()) {
      return b;
    }
    // n is in [0, 63] here, so the shift is well defined. Bit i set selects
    // lane i from `b`, which matches blend<(1 << n) - 1>(a, b).
    const __mmask64 mask = (1ULL << n) - 1;
    return _mm512_mask_blend_epi8(mask, a.values, b.values);
  }
  // Unaligned load of all 64 lanes from `ptr`.
  static Vectorized<T> loadu(const void* ptr) {
    return _mm512_loadu_si512(reinterpret_cast<const __m512i*>(ptr));
  }
  static Vectorized<T> loadu_one_fourth(const void* ptr) {
    // Fast path if only load element number of 16.
    // Note: We didn't merge it as fast path of loadu(const void* ptr, T count),
    // Because loadu(const void* ptr, T count) requires zero initialization for
    // upper 384 bits. However, by using _mm512_castsi128_si512, the upper 384
    // bits of the result are undefined.
    // TODO<leslie> We can use _mm512_zextsi128_si512 in the future,
    // since gcc 9.3 doesn't support it now.
    __m128i input_128 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(ptr));
    return _mm512_castsi128_si512(input_128);
  }
  // Loads the first `count` elements from `ptr` and sets the remaining lanes
  // to zero. `count` is expected to be in [0, size()]; other values make the
  // mask shift below undefined.
  static Vectorized<T> loadu(const void* ptr, T count) {
    if (count == size()) {
      return _mm512_loadu_si512(reinterpret_cast<const __m512i*>(ptr));
    }
    // Partial loads must leave the unloaded lanes zero. A zeroing masked load
    // guarantees that for every count, including 16. loadu_one_fourth would
    // leave the upper 384 bits undefined, so it is not used here. Masked-off
    // elements are not loaded from memory.
    const __mmask64 mask = (1ULL << count) - 1;
    return _mm512_maskz_loadu_epi8(mask, ptr);
  }
  // Stores the first `count` lanes to `ptr` (no alignment requirement).
  // count <= 0 stores nothing. `count` is expected to be at most size();
  // larger values make the mask shift below undefined.
  void store(void* ptr, int count = size()) const {
    if (count == size()) {
      // ptr need not to be aligned here. See
      // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm512-storeu-si512.html
      _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values);
    } else if (count > 0) {
      if (count == 16) {
        // Fast path if only store element number of 16: write just the low
        // 128 bits.
        _mm_storeu_si128(
            reinterpret_cast<__m128i*>(ptr), _mm512_castsi512_si128(values));
      } else {
        // Only the low `count` bytes are written.
        const __mmask64 mask = (1ULL << count) - 1;
        _mm512_mask_storeu_epi8(ptr, mask, values);
      }
    }
  }
  // Per-lane element access by reference is not provided.
  const T& operator[](int idx) const = delete;
  T& operator[](int idx) = delete;
  // Integers are real-valued: real() and conj() return the vector unchanged,
  // and imag() returns all zeros.
  Vectorized<T> real() const {
    return *this;
  }
  Vectorized<T> imag() const {
    return _mm512_set1_epi8(0);
  }
  Vectorized<T> conj() const {
    return *this;
  }
};

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free