Home / Class / Vectorized16 Class — PyTorch Architecture

Vectorized16 Class — PyTorch Architecture

Architecture documentation for the Vectorized16 class in vec512_bfloat16.h from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h lines 181–811

class Vectorized16 {
  static_assert(
      is_reduced_floating_point_v<T>,
      "Support only float16 and bfloat16.");

 private:
  __m512i values;

 public:
  using value_type = uint16_t;
  using size_type = int;
  static constexpr size_type size() {
    return 32;
  }
  Vectorized16() {
    values = _mm512_setzero_si512();
  }
  Vectorized16(__m512i v) : values(v) {}
  Vectorized16(T val) {
    value_type uw = val.x;
    values = _mm512_set1_epi16(uw);
  }
  Vectorized16(
      T val1,
      T val2,
      T val3,
      T val4,
      T val5,
      T val6,
      T val7,
      T val8,
      T val9,
      T val10,
      T val11,
      T val12,
      T val13,
      T val14,
      T val15,
      T val16,
      T val17,
      T val18,
      T val19,
      T val20,
      T val21,
      T val22,
      T val23,
      T val24,
      T val25,
      T val26,
      T val27,
      T val28,
      T val29,
      T val30,
      T val31,
      T val32) {
    values = _mm512_set_epi16(
        val32.x,
        val31.x,
        val30.x,
        val29.x,
        val28.x,
        val27.x,
        val26.x,
        val25.x,
        val24.x,
        val23.x,
        val22.x,
        val21.x,
        val20.x,
        val19.x,
        val18.x,
        val17.x,
        val16.x,
        val15.x,
        val14.x,
        val13.x,
        val12.x,
        val11.x,
        val10.x,
        val9.x,
        val8.x,
        val7.x,
        val6.x,
        val5.x,
        val4.x,
        val3.x,
        val2.x,
        val1.x);
  }
  operator __m512i() const {
    return values;
  }
  T& operator[](int idx) = delete;
  const T& operator[](int idx) const = delete;
  int zero_mask() const {
    // returns an integer mask where all zero elements are translated to 1-bit
    // and others are translated to 0-bit
    return _mm512_cmpeq_epi16_mask(values, _mm512_set1_epi16(0));
  }
  static Vectorized<T> loadu(const void* ptr, int16_t count = size()) {
    if (count == size())
      return _mm512_loadu_si512(reinterpret_cast<const __m512i*>(ptr));

    __mmask32 mask = (1ULL << count) - 1;
    return _mm512_maskz_loadu_epi16(mask, ptr);
  }
  void store(void* ptr, int count = size()) const {
    if (count == size()) {
      _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values);
    } else if (count > 0) {
      __mmask32 mask = (1ULL << count) - 1;
      _mm512_mask_storeu_epi16(ptr, mask, values);
    }
  }
  template <int64_t mask>
  static Vectorized<T> blend(const Vectorized<T>& a, const Vectorized<T>& b) {
    return _mm512_mask_blend_epi16(mask, a.values, b.values);
  }
  static Vectorized<T> blendv(
      const Vectorized<T>& a,
      const Vectorized<T>& b,
      const Vectorized<T>& mask) {
    auto all_ones = _mm512_set1_epi16(0xFFFF);
    auto mask_ = _mm512_cmp_epi16_mask(mask, all_ones, _MM_CMPINT_EQ);
    return _mm512_mask_blend_epi16(mask_, a.values, b.values);
  }
  template <typename step_t>
  static Vectorized<T> arange(
      T base = 0.f,
      step_t step = static_cast<step_t>(1)) {
    return Vectorized<T>(
        base,
        base + step,
        base + 2 * step,
        base + 3 * step,
        base + 4 * step,
        base + 5 * step,
        base + 6 * step,
        base + 7 * step,
        base + 8 * step,
        base + 9 * step,
        base + 10 * step,
        base + 11 * step,
        base + 12 * step,
        base + 13 * step,
        base + 14 * step,
        base + 15 * step,
        base + 16 * step,
        base + 17 * step,
        base + 18 * step,
        base + 19 * step,
        base + 20 * step,
        base + 21 * step,
        base + 22 * step,
        base + 23 * step,
        base + 24 * step,
        base + 25 * step,
        base + 26 * step,
        base + 27 * step,
        base + 28 * step,
        base + 29 * step,
        base + 30 * step,
        base + 31 * step);
  }
  static Vectorized<T> set(
      const Vectorized<T>& a,
      const Vectorized<T>& b,
      int64_t count = size()) {
    switch (count) {
      case 0:
        return a;
      case 1:
        return blend<1>(a, b);
      case 2:
        return blend<3>(a, b);
      case 3:
        return blend<7>(a, b);
      case 4:
        return blend<15>(a, b);
      case 5:
        return blend<31>(a, b);
      case 6:
        return blend<63>(a, b);
      case 7:
        return blend<127>(a, b);
      case 8:
        return blend<255>(a, b);
      case 9:
        return blend<511>(a, b);
      case 10:
        return blend<1023>(a, b);
      case 11:
        return blend<2047>(a, b);
      case 12:
        return blend<4095>(a, b);
      case 13:
        return blend<8191>(a, b);
      case 14:
        return blend<16383>(a, b);
      case 15:
        return blend<32767>(a, b);
      case 16:
        return blend<65535>(a, b);
      case 17:
        return blend<131071>(a, b);
      case 18:
        return blend<262143>(a, b);
      case 19:
        return blend<524287>(a, b);
      case 20:
        return blend<1048575>(a, b);
      case 21:
        return blend<2097151>(a, b);
      case 22:
        return blend<4194303>(a, b);
      case 23:
        return blend<8388607>(a, b);
      case 24:
        return blend<16777215>(a, b);
      case 25:
        return blend<33554431>(a, b);
      case 26:
        return blend<67108863>(a, b);
      case 27:
        return blend<134217727>(a, b);
      case 28:
        return blend<268435455>(a, b);
      case 29:
        return blend<536870911>(a, b);
      case 30:
        return blend<1073741823>(a, b);
      case 31:
        return blend<2147483647>(a, b);
    }
    return b;
  }
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wignored-qualifiers"

  Vectorized<T> map(SLEEF_CONST __m512 (*SLEEF_CONST_OLD vop)(__m512)) const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    const auto o1 = vop(lo);
    const auto o2 = vop(hi);
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> isnan() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    __mmask16 lo_mask, hi_mask;
    __m512 zero = _mm512_set1_ps(0.0);
    __m512i zeroi = _mm512_castps_si512(zero);
    lo_mask = _mm512_cmp_ps_mask(lo, zero, _CMP_UNORD_Q);
    lo = _mm512_castsi512_ps(
        _mm512_mask_set1_epi32(zeroi, lo_mask, 0xFFFF'FFFF));
    hi_mask = _mm512_cmp_ps_mask(hi, zero, _CMP_UNORD_Q);
    hi = _mm512_castsi512_ps(
        _mm512_mask_set1_epi32(zeroi, hi_mask, 0xFFFF'FFFF));
    return merge_compare_result(lo, hi);
  }
#pragma clang diagnostic pop
  Vectorized<T> abs() const {
    return _mm512_andnot_si512(_mm512_set1_epi16(0x8000), values);
  }
  Vectorized<T> angle() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    auto angle_lambda = [](__m512 values) {
      const auto zero_vec = _mm512_set1_ps(0.f);
      const auto nan_vec = _mm512_set1_ps(NAN);
      const auto not_nan_mask = _mm512_cmp_ps_mask(values, values, _CMP_EQ_OQ);
      const auto non_nan_mask_vec = _mm512_mask_set1_epi32(
          _mm512_castps_si512(zero_vec), not_nan_mask, 0xFFFFFFFF);
      const auto nan_mask = _mm512_cmp_ps_mask(
          _mm512_castsi512_ps(non_nan_mask_vec), zero_vec, _CMP_EQ_OQ);
      const auto pi = _mm512_set1_ps(c10::pi<float>);

      const auto neg_mask = _mm512_cmp_ps_mask(values, zero_vec, _CMP_LT_OQ);
      auto angle = _mm512_mask_blend_ps(neg_mask, zero_vec, pi);
      angle = _mm512_mask_blend_ps(nan_mask, angle, nan_vec);
      return angle;
    };
    auto o1 = angle_lambda(lo);
    auto o2 = angle_lambda(hi);
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> real() const {
    return *this;
  }
  Vectorized<T> imag() const {
    return _mm512_set1_epi16(0);
  }
  Vectorized<T> conj() const {
    return *this;
  }
  Vectorized<T> acos() const {
    return map(Sleef_acosf16_u10);
  }
  Vectorized<T> acosh() const {
    return map(Sleef_acoshf16_u10);
  }
  Vectorized<T> asin() const {
    return map(Sleef_asinf16_u10);
  }
  Vectorized<T> asinh() const {
    return map(Sleef_asinhf16_u10);
  }
  Vectorized<T> atan() const {
    return map(Sleef_atanf16_u10);
  }
  Vectorized<T> atanh() const {
    return map(Sleef_atanhf16_u10);
  }
  Vectorized<T> atan2(const Vectorized<T>& b) const {
    __m512 lo, hi;
    __m512 b1, b2;
    cvt_to_fp32<T>(values, lo, hi);
    cvt_to_fp32<T>(b.values, b1, b2);
    auto o1 = Sleef_atan2f16_u10(lo, b1);
    auto o2 = Sleef_atan2f16_u10(hi, b2);
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> copysign(const Vectorized<T>& sign) const {
    // copy sign bit (0x8000) from sign and remaining bits from values
    __m512i mask_value = _mm512_set1_epi32(~0x80008000);
    __m512i mask_signbit = _mm512_set1_epi32(0x80008000);
    return Vectorized<T>(_mm512_or_si512(
        _mm512_and_si512(values, mask_value),
        _mm512_and_si512(sign, mask_signbit)));
  }
  Vectorized<T> erf() const {
    return map(Sleef_erff16_u10);
  }
  Vectorized<T> erfc() const {
    return map(Sleef_erfcf16_u15);
  }
  Vectorized<T> erfinv() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    __at_align__ float tmp1[size() / 2], tmp2[size() / 2];
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
    for (int64_t i = 0; i < size() / 2; i++) {
      tmp1[i] = calc_erfinv(tmp1[i]);
      tmp2[i] = calc_erfinv(tmp2[i]);
    }
    auto o1 = _mm512_loadu_ps(tmp1);
    auto o2 = _mm512_loadu_ps(tmp2);
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> exp() const {
    return map(Sleef_expf16_u10);
  }
  Vectorized<T> exp2() const {
    return map(Sleef_exp2f16_u10);
  }
  Vectorized<T> expm1() const {
    return map(Sleef_expm1f16_u10);
  }
  Vectorized<T> fexp_u20() const {
    return exp();
  }
  Vectorized<T> exp_u20() const {
    return exp();
  }
  Vectorized<T> fmod(const Vectorized<T>& q) const {
    __m512 x_lo, x_hi;
    cvt_to_fp32<T>(values, x_lo, x_hi);
    __m512 q_lo, q_hi;
    cvtbf16_fp32(q.values, q_lo, q_hi);
    auto o1 = Sleef_fmodf16(x_lo, q_lo);
    auto o2 = Sleef_fmodf16(x_hi, q_hi);
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> hypot(const Vectorized<T>& b) const {
    __m512 lo, hi;
    __m512 b1, b2;
    cvt_to_fp32<T>(values, lo, hi);
    cvt_to_fp32<T>(b.values, b1, b2);
    auto o1 = Sleef_hypotf16_u05(lo, b1);
    auto o2 = Sleef_hypotf16_u05(hi, b2);
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> i0() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    __at_align__ float tmp1[size() / 2], tmp2[size() / 2];
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
    for (int64_t i = 0; i < size() / 2; i++) {
      tmp1[i] = calc_i0(tmp1[i]);
      tmp2[i] = calc_i0(tmp2[i]);
    }
    auto o1 = _mm512_loadu_ps(tmp1);
    auto o2 = _mm512_loadu_ps(tmp2);
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> i0e() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    constexpr auto sz = size();
    __at_align__ float tmp1[sz / 2], tmp2[sz / 2];
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);

    for (auto i = decltype(sz){0}; i < sz / 2; i++) {
      tmp1[i] = calc_i0e(tmp1[i]);
      tmp2[i] = calc_i0e(tmp2[i]);
    }
    const auto o1 = _mm512_loadu_ps(tmp1);
    const auto o2 = _mm512_loadu_ps(tmp2);
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> digamma() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    constexpr auto sz = size();
    __at_align__ float tmp1[sz / 2], tmp2[sz / 2];
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);

    for (auto i = decltype(sz){0}; i < sz / 2; i++) {
      tmp1[i] = calc_digamma(tmp1[i]);
      tmp2[i] = calc_digamma(tmp2[i]);
    }
    const auto o1 = _mm512_loadu_ps(tmp1);
    const auto o2 = _mm512_loadu_ps(tmp2);
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> igamma(const Vectorized<T>& x) const {
    __m512 lo, hi;
    __m512 xlo, xhi;
    cvt_to_fp32<T>(values, lo, hi);
    cvt_to_fp32<T>(x.values, xlo, xhi);
    __at_align__ float tmp1[size() / 2], tmp2[size() / 2];
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
    __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2];
    _mm512_storeu_ps(reinterpret_cast<float*>(tmpx1), xlo);
    _mm512_storeu_ps(reinterpret_cast<float*>(tmpx2), xhi);
    for (int64_t i = 0; i < size() / 2; ++i) {
      tmp1[i] = calc_igamma(tmp1[i], tmpx1[i]);
      tmp2[i] = calc_igamma(tmp2[i], tmpx2[i]);
    }
    auto o1 = _mm512_loadu_ps(tmp1);
    auto o2 = _mm512_loadu_ps(tmp2);
    return cvt_from_fp32<T>(o1, o2);
  }

  Vectorized<T> igammac(const Vectorized<T>& x) const {
    __m512 lo, hi;
    __m512 xlo, xhi;
    cvt_to_fp32<T>(values, lo, hi);
    cvt_to_fp32<T>(x.values, xlo, xhi);
    __at_align__ float tmp1[size() / 2], tmp2[size() / 2];
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
    __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2];
    _mm512_storeu_ps(reinterpret_cast<float*>(tmpx1), xlo);
    _mm512_storeu_ps(reinterpret_cast<float*>(tmpx2), xhi);
    for (int64_t i = 0; i < size() / 2; ++i) {
      tmp1[i] = calc_igammac(tmp1[i], tmpx1[i]);
      tmp2[i] = calc_igammac(tmp2[i], tmpx2[i]);
    }
    auto o1 = _mm512_loadu_ps(tmp1);
    auto o2 = _mm512_loadu_ps(tmp2);
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> log() const {
    return map(Sleef_logf16_u10);
  }
  Vectorized<T> log2() const {
    return map(Sleef_log2f16_u10);
  }
  Vectorized<T> log10() const {
    return map(Sleef_log10f16_u10);
  }
  Vectorized<T> log1p() const {
    return map(Sleef_log1pf16_u10);
  }
  Vectorized<T> sin() const {
    return map(Sleef_sinf16_u10);
  }
  Vectorized<T> sinh() const {
    return map(Sleef_sinhf16_u10);
  }
  Vectorized<T> cos() const {
    return map(Sleef_cosf16_u10);
  }
  Vectorized<T> cosh() const {
    return map(Sleef_coshf16_u10);
  }
  Vectorized<T> ceil() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    auto o1 = _mm512_ceil_ps(lo);
    auto o2 = _mm512_ceil_ps(hi);
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> floor() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    auto o1 = _mm512_floor_ps(lo);
    auto o2 = _mm512_floor_ps(hi);
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> neg() const {
    return _mm512_xor_si512(values, _mm512_set1_epi16(0x8000));
  }
  Vectorized<T> round() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    auto o1 = _mm512_roundscale_ps(
        lo, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
    auto o2 = _mm512_roundscale_ps(
        hi, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> tan() const {
    return map(Sleef_tanf16_u10);
  }
  Vectorized<T> tanh() const {
    return map(Sleef_tanhf16_u10);
  }
  Vectorized<T> trunc() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    auto o1 =
        _mm512_roundscale_ps(lo, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
    auto o2 =
        _mm512_roundscale_ps(hi, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> lgamma() const {
    return map(Sleef_lgammaf16_u10);
  }
  Vectorized<T> sqrt() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    auto o1 = _mm512_sqrt_ps(lo);
    auto o2 = _mm512_sqrt_ps(hi);
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> reciprocal() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    auto ones = _mm512_set1_ps(1);
    auto o1 = _mm512_div_ps(ones, lo);
    auto o2 = _mm512_div_ps(ones, hi);
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> rsqrt() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    auto ones = _mm512_set1_ps(1);
    auto o1 = _mm512_div_ps(ones, _mm512_sqrt_ps(lo));
    auto o2 = _mm512_div_ps(ones, _mm512_sqrt_ps(hi));
    return cvt_from_fp32<T>(o1, o2);
  }
  Vectorized<T> pow(const Vectorized<T>& b) const {
    __m512 lo, hi;
    __m512 b1, b2;
    cvt_to_fp32<T>(values, lo, hi);
    cvt_to_fp32<T>(b.values, b1, b2);
    auto o1 = Sleef_powf16_u10(lo, b1);
    auto o2 = Sleef_powf16_u10(hi, b2);
    return cvt_from_fp32<T>(o1, o2);
  }

 private:
  template <typename Op, typename VectorizedType>
  Vectorized<T> inline binary_compare(const VectorizedType& b, Op op) const {
    __m512 a_lo, a_hi;
    __m512 b_lo, b_hi;
    cvt_to_fp32<T>(values, a_lo, a_hi);
    cvt_to_fp32<T>(b.values, b_lo, b_hi);
    auto o1 = op(a_lo, b_lo);
    auto o2 = op(a_hi, b_hi);
    return cvt_from_fp32<T, /*is_compare_op*/ true>(o1, o2);
  }

 public:
  Vectorized<T> inline operator>(const Vectorized<T>& other) const {
    return binary_compare(other, [](__m512 x, __m512 y) {
      auto zero_vec = _mm512_set1_epi32(0);
      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_GT_OQ);
      return _mm512_castsi512_ps(
          _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
    });
  }
  Vectorized<T> inline operator<(const Vectorized<T>& other) const {
    return binary_compare(other, [](__m512 x, __m512 y) {
      auto zero_vec = _mm512_set1_epi32(0);
      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_LT_OQ);
      return _mm512_castsi512_ps(
          _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
    });
  }
  Vectorized<T> inline operator>=(const Vectorized<T>& other) const {
    return binary_compare(other, [](__m512 x, __m512 y) {
      auto zero_vec = _mm512_set1_epi32(0);
      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_GE_OQ);
      return _mm512_castsi512_ps(
          _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
    });
  }
  Vectorized<T> inline operator<=(const Vectorized<T>& other) const {
    return binary_compare(other, [](__m512 x, __m512 y) {
      auto zero_vec = _mm512_set1_epi32(0);
      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_LE_OQ);
      return _mm512_castsi512_ps(
          _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
    });
  }
  Vectorized<T> inline operator==(const Vectorized16<T>& other) const {
    return binary_compare(other, [](__m512 x, __m512 y) {
      auto zero_vec = _mm512_set1_epi32(0);
      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_EQ_OQ);
      return _mm512_castsi512_ps(
          _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
    });
  }
  Vectorized<T> inline operator!=(const Vectorized16<T>& other) const {
    return binary_compare(other, [](__m512 x, __m512 y) {
      auto zero_vec = _mm512_set1_epi32(0);
      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_NEQ_UQ);
      return _mm512_castsi512_ps(
          _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
    });
  }
};

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free