Home / Class / Vectorizedf8 Class — PyTorch Architecture

Vectorizedf8 Class — PyTorch Architecture

Architecture documentation for the Vectorizedf8 class in vec512_float8.h from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/cpu/vec/vec512/vec512_float8.h lines 260–422

// Vectorized wrapper for 64 packed float8 values (one byte each) held in a
// single 512-bit integer register. T selects the float8 encoding:
// at::Float8_e4m3fn or at::Float8_e5m2.
// NOTE(review): this class appears inside a `template <typename T>` context
// in the enclosing file (T is not declared in this excerpt).
class Vectorizedf8 {
  // Restrict T to the two supported float8 encodings. The condition is
  // already a bool, so the former std::integral_constant<bool, ...>::value
  // wrapper was redundant. The message previously claimed "only e4m3" even
  // though the condition also accepts e5m2 — fixed to match the condition.
  static_assert(
      std::is_same_v<T, at::Float8_e4m3fn> ||
          std::is_same_v<T, at::Float8_e5m2>,
      "Support only float8 e4m3 and e5m2.");

 private:
  __m512i values; // 64 packed float8 elements, one byte per element.

  // Applies `op` element-wise in fp32 precision: each 128-bit quarter of
  // `values` (16 float8 elements) is widened to an __m512 of fp32, combined
  // with the matching quarter of `b` via `op`, then narrowed back to float8
  // and re-packed into one __m512i. Used by the comparison operators below,
  // where `op` produces fp32 lanes of all-ones (true) or all-zeros (false).
  template <typename Op, typename VectorizedType>
  Vectorized<T> inline binary_compare(const VectorizedType& b, Op op) const {
    __m512 a0, a1, a2, a3;
    __m512 b0, b1, b2, b3;
    __m512 o0, o1, o2, o3;
    // Widen: pick the fp8 -> fp32 converter matching T's encoding.
    if constexpr (std::is_same_v<T, c10::Float8_e4m3fn>) {
      cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(values, 0), a0);
      cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(b.values, 0), b0);
      cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(values, 1), a1);
      cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(b.values, 1), b1);
      cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(values, 2), a2);
      cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(b.values, 2), b2);
      cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(values, 3), a3);
      cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(b.values, 3), b3);
    } else {
      cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(values, 0), a0);
      cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(b.values, 0), b0);
      cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(values, 1), a1);
      cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(b.values, 1), b1);
      cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(values, 2), a2);
      cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(b.values, 2), b2);
      cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(values, 3), a3);
      cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(b.values, 3), b3);
    }

    // Combine each quarter in fp32.
    o0 = op(a0, b0);
    o1 = op(a1, b1);
    o2 = op(a2, b2);
    o3 = op(a3, b3);
    // Narrow: fp32 -> fp8 for each quarter, matching T's encoding.
    __m128i o128_0, o128_1, o128_2, o128_3;
    if constexpr (std::is_same_v<T, c10::Float8_e4m3fn>) {
      o128_0 = cvtfp32_fp8e4m3(o0);
      o128_1 = cvtfp32_fp8e4m3(o1);
      o128_2 = cvtfp32_fp8e4m3(o2);
      o128_3 = cvtfp32_fp8e4m3(o3);
    } else {
      o128_0 = cvtfp32_fp8e5m2(o0);
      o128_1 = cvtfp32_fp8e5m2(o1);
      o128_2 = cvtfp32_fp8e5m2(o2);
      o128_3 = cvtfp32_fp8e5m2(o3);
    }

    // Re-pack the four 128-bit quarters into a single 512-bit result.
    __m512i result = _mm512_setzero_si512();
    result = _mm512_inserti32x4(result, o128_0, 0);
    result = _mm512_inserti32x4(result, o128_1, 1);
    result = _mm512_inserti32x4(result, o128_2, 2);
    result = _mm512_inserti32x4(result, o128_3, 3);

    return result;
  }

 public:
  using value_type = uint8_t;
  using size_type = int;
  // Number of float8 elements held: 512 bits / 8 bits per element.
  static constexpr size_type size() {
    return 64;
  }
  Vectorizedf8() {}
  Vectorizedf8(__m512i v) : values(v) {}
  // Broadcast a single float8 value (its raw byte representation `val.x`)
  // to all 64 lanes.
  Vectorizedf8(T val) {
    value_type uw = val.x;
    values = _mm512_set1_epi8(uw);
  }
  operator __m512i() const {
    return values;
  }
  // Per-element access is intentionally unsupported on the SIMD register.
  T& operator[](int idx) = delete;
  const T& operator[](int idx) const = delete;
  // Unaligned load of `count` float8 elements from `ptr`; lanes past
  // `count` are zeroed.
  static Vectorized<T> loadu(const void* ptr, int16_t count = size()) {
    if (count == size()) {
      return _mm512_loadu_si512(reinterpret_cast<const __m512i*>(ptr));
    } else if (count == 16) {
      // Fast path when exactly 16 elements are requested: one 128-bit load.
      __m128i input_128 =
          _mm_loadu_si128(reinterpret_cast<const __m128i*>(ptr));
      // _mm512_zextsi128_si512 guarantees the upper 384 bits are zero;
      // the previous _mm512_castsi128_si512 leaves them undefined, which
      // was inconsistent with the zeroing masked path below.
      return _mm512_zextsi128_si512(input_128);
    } else {
      // Masked load: low `count` bytes loaded, remaining lanes zeroed.
      __mmask64 mask = (1ULL << count) - 1;
      return _mm512_maskz_loadu_epi8(mask, ptr);
    }
  }
  // Unaligned store of the low `count` float8 elements to `ptr`;
  // a non-positive `count` is a no-op.
  void store(void* ptr, int count = size()) const {
    if (count == size()) {
      _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values);
    } else if (count > 0) {
      if (count == 16) {
        // Fast path when exactly 16 elements are requested: one 128-bit store.
        _mm_storeu_si128(
            reinterpret_cast<__m128i*>(ptr), _mm512_castsi512_si128(values));
      } else {
        // Masked store: only the low `count` bytes are written.
        __mmask64 mask = (1ULL << count) - 1;
        _mm512_mask_storeu_epi8(ptr, mask, values);
      }
    }
  }

  // Absolute value: clear the sign bit (bit 7) of every byte. Valid for
  // both e4m3fn and e5m2, which store the sign in the top bit.
  Vectorized<T> abs() const {
    return _mm512_andnot_si512(_mm512_set1_epi8(0x80), values);
  }

  // All comparisons below are performed in fp32 via binary_compare; true
  // lanes become all-ones fp32 bit patterns, false lanes all-zeros, then
  // the result is narrowed back to float8.
  Vectorized<T> inline operator==(const Vectorizedf8<T>& other) const {
    return binary_compare(other, [](__m512 x, __m512 y) {
      auto zero_vec = _mm512_set1_epi32(0);
      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_EQ_OQ);
      return _mm512_castsi512_ps(
          _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
    });
  }

  Vectorized<T> inline operator!=(const Vectorizedf8<T>& other) const {
    return binary_compare(other, [](__m512 x, __m512 y) {
      auto zero_vec = _mm512_set1_epi32(0);
      // NEQ_UQ: unordered (NaN) operands compare as "not equal".
      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_NEQ_UQ);
      return _mm512_castsi512_ps(
          _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
    });
  }

  Vectorized<T> inline operator>(const Vectorizedf8<T>& other) const {
    return binary_compare(other, [](__m512 x, __m512 y) {
      auto zero_vec = _mm512_set1_epi32(0);
      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_GT_OQ);
      return _mm512_castsi512_ps(
          _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
    });
  }

  Vectorized<T> inline operator>=(const Vectorizedf8<T>& other) const {
    return binary_compare(other, [](__m512 x, __m512 y) {
      auto zero_vec = _mm512_set1_epi32(0);
      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_GE_OQ);
      return _mm512_castsi512_ps(
          _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
    });
  }

  Vectorized<T> inline operator<(const Vectorizedf8<T>& other) const {
    return binary_compare(other, [](__m512 x, __m512 y) {
      auto zero_vec = _mm512_set1_epi32(0);
      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_LT_OQ);
      return _mm512_castsi512_ps(
          _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
    });
  }

  Vectorized<T> inline operator<=(const Vectorizedf8<T>& other) const {
    return binary_compare(other, [](__m512 x, __m512 y) {
      auto zero_vec = _mm512_set1_epi32(0);
      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_LE_OQ);
      return _mm512_castsi512_ps(
          _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
    });
  }
};

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free