Vectorized16 Class — pytorch Architecture
Architecture documentation for the Vectorized16 class template in vec512_bfloat16.h from the pytorch codebase. (The class statically asserts is_reduced_floating_point_v<T>, i.e. it supports only float16 and bfloat16.)
Entity Profile
Source Code
aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h lines 180–811
// Vectorized16<T>: a 512-bit SIMD vector of 32 reduced-precision 16-bit
// floating-point lanes. T must be a reduced floating-point type (BFloat16 or
// Half), enforced by the static_assert below. Lanes are stored as raw 16-bit
// payloads in an __m512i; floating-point operations widen to two __m512 fp32
// halves via cvt_to_fp32<T>, compute in fp32 (intrinsics or SLEEF), then
// narrow back via cvt_from_fp32<T>.
template <typename T>
class Vectorized16 {
  static_assert(
      is_reduced_floating_point_v<T>,
      "Support only float16 and bfloat16.");

 private:
  // 32 packed 16-bit payloads (raw bf16/fp16 bit patterns).
  __m512i values;

 public:
  // Lanes are manipulated as raw 16-bit payloads rather than as T.
  using value_type = uint16_t;
  using size_type = int;

  // Number of 16-bit lanes in a 512-bit register.
  static constexpr size_type size() {
    return 32;
  }

  Vectorized16() {
    values = _mm512_setzero_si512();
  }

  Vectorized16(__m512i v) : values(v) {}

  // Broadcast a single scalar to all 32 lanes. `val.x` is the raw 16-bit
  // payload of BFloat16/Half.
  Vectorized16(T val) {
    value_type uw = val.x;
    values = _mm512_set1_epi16(uw);
  }

  // Element-wise constructor: val1 lands in lane 0. _mm512_set_epi16 takes
  // its arguments from the most-significant lane downward, hence the
  // reversed order in the call below.
  Vectorized16(
      T val1,
      T val2,
      T val3,
      T val4,
      T val5,
      T val6,
      T val7,
      T val8,
      T val9,
      T val10,
      T val11,
      T val12,
      T val13,
      T val14,
      T val15,
      T val16,
      T val17,
      T val18,
      T val19,
      T val20,
      T val21,
      T val22,
      T val23,
      T val24,
      T val25,
      T val26,
      T val27,
      T val28,
      T val29,
      T val30,
      T val31,
      T val32) {
    values = _mm512_set_epi16(
        val32.x,
        val31.x,
        val30.x,
        val29.x,
        val28.x,
        val27.x,
        val26.x,
        val25.x,
        val24.x,
        val23.x,
        val22.x,
        val21.x,
        val20.x,
        val19.x,
        val18.x,
        val17.x,
        val16.x,
        val15.x,
        val14.x,
        val13.x,
        val12.x,
        val11.x,
        val10.x,
        val9.x,
        val8.x,
        val7.x,
        val6.x,
        val5.x,
        val4.x,
        val3.x,
        val2.x,
        val1.x);
  }

  operator __m512i() const {
    return values;
  }

  // Direct lane access is intentionally unsupported; store() into a buffer
  // to inspect individual elements.
  T& operator[](int idx) = delete;
  const T& operator[](int idx) const = delete;

  int zero_mask() const {
    // returns an integer mask where all zero elements are translated to 1-bit
    // and others are translated to 0-bit
    return _mm512_cmpeq_epi16_mask(values, _mm512_set1_epi16(0));
  }

  // Load `count` elements (default: full vector). Partial loads use a masked
  // load that zero-fills the remaining lanes.
  static Vectorized<T> loadu(const void* ptr, int16_t count = size()) {
    if (count == size())
      return _mm512_loadu_si512(reinterpret_cast<const __m512i*>(ptr));
    __mmask32 mask = (1ULL << count) - 1;
    return _mm512_maskz_loadu_epi16(mask, ptr);
  }

  // Store `count` elements; a partial store leaves the tail of `ptr`
  // untouched.
  void store(void* ptr, int count = size()) const {
    if (count == size()) {
      _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values);
    } else if (count > 0) {
      __mmask32 mask = (1ULL << count) - 1;
      _mm512_mask_storeu_epi16(ptr, mask, values);
    }
  }

  // Per-lane select by compile-time mask: bit i set -> take b[i], else a[i].
  template <int64_t mask>
  static Vectorized<T> blend(const Vectorized<T>& a, const Vectorized<T>& b) {
    return _mm512_mask_blend_epi16(mask, a.values, b.values);
  }

  // Per-lane select by runtime mask: lanes of `mask` equal to all-ones
  // (0xFFFF) select from b, all other lanes select from a.
  static Vectorized<T> blendv(
      const Vectorized<T>& a,
      const Vectorized<T>& b,
      const Vectorized<T>& mask) {
    auto all_ones = _mm512_set1_epi16(0xFFFF);
    auto mask_ = _mm512_cmp_epi16_mask(mask, all_ones, _MM_CMPINT_EQ);
    return _mm512_mask_blend_epi16(mask_, a.values, b.values);
  }

  // {base, base + step, base + 2*step, ..., base + 31*step}
  template <typename step_t>
  static Vectorized<T> arange(
      T base = 0.f,
      step_t step = static_cast<step_t>(1)) {
    return Vectorized<T>(
        base,
        base + step,
        base + 2 * step,
        base + 3 * step,
        base + 4 * step,
        base + 5 * step,
        base + 6 * step,
        base + 7 * step,
        base + 8 * step,
        base + 9 * step,
        base + 10 * step,
        base + 11 * step,
        base + 12 * step,
        base + 13 * step,
        base + 14 * step,
        base + 15 * step,
        base + 16 * step,
        base + 17 * step,
        base + 18 * step,
        base + 19 * step,
        base + 20 * step,
        base + 21 * step,
        base + 22 * step,
        base + 23 * step,
        base + 24 * step,
        base + 25 * step,
        base + 26 * step,
        base + 27 * step,
        base + 28 * step,
        base + 29 * step,
        base + 30 * step,
        base + 31 * step);
  }

  // First `count` lanes come from b, the rest from a. Each case's blend mask
  // is (1 << count) - 1, i.e. the low `count` bits set.
  static Vectorized<T> set(
      const Vectorized<T>& a,
      const Vectorized<T>& b,
      int64_t count = size()) {
    switch (count) {
      case 0:
        return a;
      case 1:
        return blend<1>(a, b);
      case 2:
        return blend<3>(a, b);
      case 3:
        return blend<7>(a, b);
      case 4:
        return blend<15>(a, b);
      case 5:
        return blend<31>(a, b);
      case 6:
        return blend<63>(a, b);
      case 7:
        return blend<127>(a, b);
      case 8:
        return blend<255>(a, b);
      case 9:
        return blend<511>(a, b);
      case 10:
        return blend<1023>(a, b);
      case 11:
        return blend<2047>(a, b);
      case 12:
        return blend<4095>(a, b);
      case 13:
        return blend<8191>(a, b);
      case 14:
        return blend<16383>(a, b);
      case 15:
        return blend<32767>(a, b);
      case 16:
        return blend<65535>(a, b);
      case 17:
        return blend<131071>(a, b);
      case 18:
        return blend<262143>(a, b);
      case 19:
        return blend<524287>(a, b);
      case 20:
        return blend<1048575>(a, b);
      case 21:
        return blend<2097151>(a, b);
      case 22:
        return blend<4194303>(a, b);
      case 23:
        return blend<8388607>(a, b);
      case 24:
        return blend<16777215>(a, b);
      case 25:
        return blend<33554431>(a, b);
      case 26:
        return blend<67108863>(a, b);
      case 27:
        return blend<134217727>(a, b);
      case 28:
        return blend<268435455>(a, b);
      case 29:
        return blend<536870911>(a, b);
      case 30:
        return blend<1073741823>(a, b);
      case 31:
        return blend<2147483647>(a, b);
    }
    return b;
  }

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wignored-qualifiers"
  // Apply a unary fp32 kernel (typically a SLEEF function) to every lane:
  // widen to two fp32 halves, run `vop` on each half, narrow back to T.
  Vectorized<T> map(SLEEF_CONST __m512 (*SLEEF_CONST_OLD vop)(__m512)) const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    const auto o1 = vop(lo);
    const auto o2 = vop(hi);
    return cvt_from_fp32<T>(o1, o2);
  }

  // NaN lanes become all-ones, other lanes become zero. Uses an unordered
  // compare (_CMP_UNORD_Q), which is true exactly when an operand is NaN.
  Vectorized<T> isnan() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    __mmask16 lo_mask, hi_mask;
    __m512 zero = _mm512_set1_ps(0.0);
    __m512i zeroi = _mm512_castps_si512(zero);
    lo_mask = _mm512_cmp_ps_mask(lo, zero, _CMP_UNORD_Q);
    lo = _mm512_castsi512_ps(
        _mm512_mask_set1_epi32(zeroi, lo_mask, 0xFFFF'FFFF));
    hi_mask = _mm512_cmp_ps_mask(hi, zero, _CMP_UNORD_Q);
    hi = _mm512_castsi512_ps(
        _mm512_mask_set1_epi32(zeroi, hi_mask, 0xFFFF'FFFF));
    return merge_compare_result(lo, hi);
  }
#pragma clang diagnostic pop

  // Clear the sign bit (0x8000) of every lane; andnot computes ~mask & value.
  Vectorized<T> abs() const {
    return _mm512_andnot_si512(_mm512_set1_epi16(0x8000), values);
  }

  // angle(x): pi for negative lanes, 0 for non-negative, NaN propagated.
  Vectorized<T> angle() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    auto angle_lambda = [](__m512 values) {
      const auto zero_vec = _mm512_set1_ps(0.f);
      const auto nan_vec = _mm512_set1_ps(NAN);
      // x == x is false only for NaN lanes.
      const auto not_nan_mask = _mm512_cmp_ps_mask(values, values, _CMP_EQ_OQ);
      const auto non_nan_mask_vec = _mm512_mask_set1_epi32(
          _mm512_castps_si512(zero_vec), not_nan_mask, 0xFFFFFFFF);
      const auto nan_mask = _mm512_cmp_ps_mask(
          _mm512_castsi512_ps(non_nan_mask_vec), zero_vec, _CMP_EQ_OQ);
      const auto pi = _mm512_set1_ps(c10::pi<float>);
      const auto neg_mask = _mm512_cmp_ps_mask(values, zero_vec, _CMP_LT_OQ);
      auto angle = _mm512_mask_blend_ps(neg_mask, zero_vec, pi);
      angle = _mm512_mask_blend_ps(nan_mask, angle, nan_vec);
      return angle;
    };
    auto o1 = angle_lambda(lo);
    auto o2 = angle_lambda(hi);
    return cvt_from_fp32<T>(o1, o2);
  }

  // Real-valued type: real() and conj() are identity, imag() is zero.
  Vectorized<T> real() const {
    return *this;
  }
  Vectorized<T> imag() const {
    return _mm512_set1_epi16(0);
  }
  Vectorized<T> conj() const {
    return *this;
  }

  // SLEEF-backed transcendentals. Suffix u10/u05/u15 is the maximum error
  // bound in ULPs of the underlying SLEEF routine.
  Vectorized<T> acos() const {
    return map(Sleef_acosf16_u10);
  }
  Vectorized<T> acosh() const {
    return map(Sleef_acoshf16_u10);
  }
  Vectorized<T> asin() const {
    return map(Sleef_asinf16_u10);
  }
  Vectorized<T> asinh() const {
    return map(Sleef_asinhf16_u10);
  }
  Vectorized<T> atan() const {
    return map(Sleef_atanf16_u10);
  }
  Vectorized<T> atanh() const {
    return map(Sleef_atanhf16_u10);
  }

  Vectorized<T> atan2(const Vectorized<T>& b) const {
    __m512 lo, hi;
    __m512 b1, b2;
    cvt_to_fp32<T>(values, lo, hi);
    cvt_to_fp32<T>(b.values, b1, b2);
    auto o1 = Sleef_atan2f16_u10(lo, b1);
    auto o2 = Sleef_atan2f16_u10(hi, b2);
    return cvt_from_fp32<T>(o1, o2);
  }

  Vectorized<T> copysign(const Vectorized<T>& sign) const {
    // copy sign bit (0x8000) from sign and remaining bits from values
    __m512i mask_value = _mm512_set1_epi32(~0x80008000);
    __m512i mask_signbit = _mm512_set1_epi32(0x80008000);
    return Vectorized<T>(_mm512_or_si512(
        _mm512_and_si512(values, mask_value),
        _mm512_and_si512(sign, mask_signbit)));
  }

  Vectorized<T> erf() const {
    return map(Sleef_erff16_u10);
  }
  Vectorized<T> erfc() const {
    return map(Sleef_erfcf16_u15);
  }

  // No vectorized kernel available: spill to aligned scalar buffers, apply
  // the scalar implementation per element, then reload.
  Vectorized<T> erfinv() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    __at_align__ float tmp1[size() / 2], tmp2[size() / 2];
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
    for (int64_t i = 0; i < size() / 2; i++) {
      tmp1[i] = calc_erfinv(tmp1[i]);
      tmp2[i] = calc_erfinv(tmp2[i]);
    }
    auto o1 = _mm512_loadu_ps(tmp1);
    auto o2 = _mm512_loadu_ps(tmp2);
    return cvt_from_fp32<T>(o1, o2);
  }

  Vectorized<T> exp() const {
    return map(Sleef_expf16_u10);
  }
  Vectorized<T> exp2() const {
    return map(Sleef_exp2f16_u10);
  }
  Vectorized<T> expm1() const {
    return map(Sleef_expm1f16_u10);
  }
  // Lower-precision exp variants fall back to the full-precision kernel here.
  Vectorized<T> fexp_u20() const {
    return exp();
  }
  Vectorized<T> exp_u20() const {
    return exp();
  }

  Vectorized<T> fmod(const Vectorized<T>& q) const {
    __m512 x_lo, x_hi;
    cvt_to_fp32<T>(values, x_lo, x_hi);
    __m512 q_lo, q_hi;
    // Bug fix: convert the divisor with the type-dispatched cvt_to_fp32<T>.
    // The previous code called cvtbf16_fp32(q.values, ...) unconditionally,
    // which mis-converted q when T is Half (fp16 bits read as bf16).
    cvt_to_fp32<T>(q.values, q_lo, q_hi);
    auto o1 = Sleef_fmodf16(x_lo, q_lo);
    auto o2 = Sleef_fmodf16(x_hi, q_hi);
    return cvt_from_fp32<T>(o1, o2);
  }

  Vectorized<T> hypot(const Vectorized<T>& b) const {
    __m512 lo, hi;
    __m512 b1, b2;
    cvt_to_fp32<T>(values, lo, hi);
    cvt_to_fp32<T>(b.values, b1, b2);
    auto o1 = Sleef_hypotf16_u05(lo, b1);
    auto o2 = Sleef_hypotf16_u05(hi, b2);
    return cvt_from_fp32<T>(o1, o2);
  }

  // Modified Bessel function of the first kind, order 0 (scalar fallback).
  Vectorized<T> i0() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    __at_align__ float tmp1[size() / 2], tmp2[size() / 2];
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
    for (int64_t i = 0; i < size() / 2; i++) {
      tmp1[i] = calc_i0(tmp1[i]);
      tmp2[i] = calc_i0(tmp2[i]);
    }
    auto o1 = _mm512_loadu_ps(tmp1);
    auto o2 = _mm512_loadu_ps(tmp2);
    return cvt_from_fp32<T>(o1, o2);
  }

  // Exponentially scaled i0 (scalar fallback).
  Vectorized<T> i0e() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    constexpr auto sz = size();
    __at_align__ float tmp1[sz / 2], tmp2[sz / 2];
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
    for (auto i = decltype(sz){0}; i < sz / 2; i++) {
      tmp1[i] = calc_i0e(tmp1[i]);
      tmp2[i] = calc_i0e(tmp2[i]);
    }
    const auto o1 = _mm512_loadu_ps(tmp1);
    const auto o2 = _mm512_loadu_ps(tmp2);
    return cvt_from_fp32<T>(o1, o2);
  }

  // Digamma function (scalar fallback).
  Vectorized<T> digamma() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    constexpr auto sz = size();
    __at_align__ float tmp1[sz / 2], tmp2[sz / 2];
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
    for (auto i = decltype(sz){0}; i < sz / 2; i++) {
      tmp1[i] = calc_digamma(tmp1[i]);
      tmp2[i] = calc_digamma(tmp2[i]);
    }
    const auto o1 = _mm512_loadu_ps(tmp1);
    const auto o2 = _mm512_loadu_ps(tmp2);
    return cvt_from_fp32<T>(o1, o2);
  }

  // Regularized lower incomplete gamma (scalar fallback, binary op).
  Vectorized<T> igamma(const Vectorized<T>& x) const {
    __m512 lo, hi;
    __m512 xlo, xhi;
    cvt_to_fp32<T>(values, lo, hi);
    cvt_to_fp32<T>(x.values, xlo, xhi);
    __at_align__ float tmp1[size() / 2], tmp2[size() / 2];
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
    __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2];
    _mm512_storeu_ps(reinterpret_cast<float*>(tmpx1), xlo);
    _mm512_storeu_ps(reinterpret_cast<float*>(tmpx2), xhi);
    for (int64_t i = 0; i < size() / 2; ++i) {
      tmp1[i] = calc_igamma(tmp1[i], tmpx1[i]);
      tmp2[i] = calc_igamma(tmp2[i], tmpx2[i]);
    }
    auto o1 = _mm512_loadu_ps(tmp1);
    auto o2 = _mm512_loadu_ps(tmp2);
    return cvt_from_fp32<T>(o1, o2);
  }

  // Regularized upper incomplete gamma (scalar fallback, binary op).
  Vectorized<T> igammac(const Vectorized<T>& x) const {
    __m512 lo, hi;
    __m512 xlo, xhi;
    cvt_to_fp32<T>(values, lo, hi);
    cvt_to_fp32<T>(x.values, xlo, xhi);
    __at_align__ float tmp1[size() / 2], tmp2[size() / 2];
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
    _mm512_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
    __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2];
    _mm512_storeu_ps(reinterpret_cast<float*>(tmpx1), xlo);
    _mm512_storeu_ps(reinterpret_cast<float*>(tmpx2), xhi);
    for (int64_t i = 0; i < size() / 2; ++i) {
      tmp1[i] = calc_igammac(tmp1[i], tmpx1[i]);
      tmp2[i] = calc_igammac(tmp2[i], tmpx2[i]);
    }
    auto o1 = _mm512_loadu_ps(tmp1);
    auto o2 = _mm512_loadu_ps(tmp2);
    return cvt_from_fp32<T>(o1, o2);
  }

  Vectorized<T> log() const {
    return map(Sleef_logf16_u10);
  }
  Vectorized<T> log2() const {
    return map(Sleef_log2f16_u10);
  }
  Vectorized<T> log10() const {
    return map(Sleef_log10f16_u10);
  }
  Vectorized<T> log1p() const {
    return map(Sleef_log1pf16_u10);
  }
  Vectorized<T> sin() const {
    return map(Sleef_sinf16_u10);
  }
  Vectorized<T> sinh() const {
    return map(Sleef_sinhf16_u10);
  }
  Vectorized<T> cos() const {
    return map(Sleef_cosf16_u10);
  }
  Vectorized<T> cosh() const {
    return map(Sleef_coshf16_u10);
  }

  Vectorized<T> ceil() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    auto o1 = _mm512_ceil_ps(lo);
    auto o2 = _mm512_ceil_ps(hi);
    return cvt_from_fp32<T>(o1, o2);
  }

  Vectorized<T> floor() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    auto o1 = _mm512_floor_ps(lo);
    auto o2 = _mm512_floor_ps(hi);
    return cvt_from_fp32<T>(o1, o2);
  }

  // Flip the sign bit of every lane; no fp32 round-trip needed.
  Vectorized<T> neg() const {
    return _mm512_xor_si512(values, _mm512_set1_epi16(0x8000));
  }

  // Round to nearest, ties to even (current rounding mode suppressed).
  Vectorized<T> round() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    auto o1 = _mm512_roundscale_ps(
        lo, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
    auto o2 = _mm512_roundscale_ps(
        hi, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
    return cvt_from_fp32<T>(o1, o2);
  }

  Vectorized<T> tan() const {
    return map(Sleef_tanf16_u10);
  }
  Vectorized<T> tanh() const {
    return map(Sleef_tanhf16_u10);
  }

  // Round toward zero.
  Vectorized<T> trunc() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    auto o1 =
        _mm512_roundscale_ps(lo, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
    auto o2 =
        _mm512_roundscale_ps(hi, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
    return cvt_from_fp32<T>(o1, o2);
  }

  Vectorized<T> lgamma() const {
    return map(Sleef_lgammaf16_u10);
  }

  Vectorized<T> sqrt() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    auto o1 = _mm512_sqrt_ps(lo);
    auto o2 = _mm512_sqrt_ps(hi);
    return cvt_from_fp32<T>(o1, o2);
  }

  // 1 / x computed with a full-precision divide (not the rcp approximation).
  Vectorized<T> reciprocal() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    auto ones = _mm512_set1_ps(1);
    auto o1 = _mm512_div_ps(ones, lo);
    auto o2 = _mm512_div_ps(ones, hi);
    return cvt_from_fp32<T>(o1, o2);
  }

  // 1 / sqrt(x) via full-precision sqrt + divide (not the rsqrt
  // approximation).
  Vectorized<T> rsqrt() const {
    __m512 lo, hi;
    cvt_to_fp32<T>(values, lo, hi);
    auto ones = _mm512_set1_ps(1);
    auto o1 = _mm512_div_ps(ones, _mm512_sqrt_ps(lo));
    auto o2 = _mm512_div_ps(ones, _mm512_sqrt_ps(hi));
    return cvt_from_fp32<T>(o1, o2);
  }

  Vectorized<T> pow(const Vectorized<T>& b) const {
    __m512 lo, hi;
    __m512 b1, b2;
    cvt_to_fp32<T>(values, lo, hi);
    cvt_to_fp32<T>(b.values, b1, b2);
    auto o1 = Sleef_powf16_u10(lo, b1);
    auto o2 = Sleef_powf16_u10(hi, b2);
    return cvt_from_fp32<T>(o1, o2);
  }

 private:
  // Shared driver for the comparison operators: widen both operands to fp32,
  // apply `op` (which returns all-ones/all-zero fp32 lanes), narrow back as
  // a compare result (each T lane becomes 0xFFFF or 0x0000).
  template <typename Op, typename VectorizedType>
  Vectorized<T> inline binary_compare(const VectorizedType& b, Op op) const {
    __m512 a_lo, a_hi;
    __m512 b_lo, b_hi;
    cvt_to_fp32<T>(values, a_lo, a_hi);
    cvt_to_fp32<T>(b.values, b_lo, b_hi);
    auto o1 = op(a_lo, b_lo);
    auto o2 = op(a_hi, b_hi);
    return cvt_from_fp32<T, /*is_compare_op*/ true>(o1, o2);
  }

 public:
  // Ordered, quiet comparisons: NaN operands compare false (except !=,
  // which uses an unordered compare so NaN lanes compare true).
  Vectorized<T> inline operator>(const Vectorized<T>& other) const {
    return binary_compare(other, [](__m512 x, __m512 y) {
      auto zero_vec = _mm512_set1_epi32(0);
      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_GT_OQ);
      return _mm512_castsi512_ps(
          _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
    });
  }
  Vectorized<T> inline operator<(const Vectorized<T>& other) const {
    return binary_compare(other, [](__m512 x, __m512 y) {
      auto zero_vec = _mm512_set1_epi32(0);
      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_LT_OQ);
      return _mm512_castsi512_ps(
          _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
    });
  }
  Vectorized<T> inline operator>=(const Vectorized<T>& other) const {
    return binary_compare(other, [](__m512 x, __m512 y) {
      auto zero_vec = _mm512_set1_epi32(0);
      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_GE_OQ);
      return _mm512_castsi512_ps(
          _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
    });
  }
  Vectorized<T> inline operator<=(const Vectorized<T>& other) const {
    return binary_compare(other, [](__m512 x, __m512 y) {
      auto zero_vec = _mm512_set1_epi32(0);
      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_LE_OQ);
      return _mm512_castsi512_ps(
          _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
    });
  }
  Vectorized<T> inline operator==(const Vectorized16<T>& other) const {
    return binary_compare(other, [](__m512 x, __m512 y) {
      auto zero_vec = _mm512_set1_epi32(0);
      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_EQ_OQ);
      return _mm512_castsi512_ps(
          _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
    });
  }
  Vectorized<T> inline operator!=(const Vectorized16<T>& other) const {
    return binary_compare(other, [](__m512 x, __m512 y) {
      auto zero_vec = _mm512_set1_epi32(0);
      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_NEQ_UQ);
      return _mm512_castsi512_ps(
          _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
    });
  }
};
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free