Vectorizedf8 Class — PyTorch Architecture
Architecture documentation for the Vectorizedf8 class in vec512_float8.h from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/cpu/vec/vec512/vec512_float8.h lines 260–422
class Vectorizedf8 {
  // Both CPU-supported 8-bit float formats are accepted. (The original
  // diagnostic mentioned only e4m3, contradicting the condition, and
  // wrapped a plain boolean in a redundant std::integral_constant.)
  static_assert(
      std::is_same_v<T, at::Float8_e4m3fn> ||
          std::is_same_v<T, at::Float8_e5m2>,
      "Support only float8 e4m3 and e5m2.");

 private:
  // 64 fp8 values packed into a single 512-bit integer register.
  __m512i values;

  // Shared implementation of the comparison operators below.
  //
  // There is no native AVX-512 fp8 compare, so each 128-bit quarter
  // (16 fp8 lanes) of both operands is widened to a __m512 of fp32,
  // `op` is applied on the fp32 vectors, and the four fp32 results are
  // narrowed back into one packed fp8 vector. `op` is expected to
  // produce a per-lane fp32 mask (all-ones bits or zero), which the
  // narrowing conversion maps to the fp8 mask convention.
  template <typename Op, typename VectorizedType>
  Vectorized<T> inline binary_compare(const VectorizedType& b, Op op) const {
    __m512 a0, a1, a2, a3;
    __m512 b0, b1, b2, b3;
    __m512 o0, o1, o2, o3;
    if constexpr (std::is_same_v<T, c10::Float8_e4m3fn>) {
      cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(values, 0), a0);
      cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(b.values, 0), b0);
      cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(values, 1), a1);
      cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(b.values, 1), b1);
      cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(values, 2), a2);
      cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(b.values, 2), b2);
      cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(values, 3), a3);
      cvtfp8e4m3_fp32(_mm512_extracti32x4_epi32(b.values, 3), b3);
    } else {
      cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(values, 0), a0);
      cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(b.values, 0), b0);
      cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(values, 1), a1);
      cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(b.values, 1), b1);
      cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(values, 2), a2);
      cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(b.values, 2), b2);
      cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(values, 3), a3);
      cvtfp8e5m2_fp32(_mm512_extracti32x4_epi32(b.values, 3), b3);
    }
    o0 = op(a0, b0);
    o1 = op(a1, b1);
    o2 = op(a2, b2);
    o3 = op(a3, b3);
    // Narrow the four fp32 results back to 16 fp8 lanes each...
    __m128i o128_0, o128_1, o128_2, o128_3;
    if constexpr (std::is_same_v<T, c10::Float8_e4m3fn>) {
      o128_0 = cvtfp32_fp8e4m3(o0);
      o128_1 = cvtfp32_fp8e4m3(o1);
      o128_2 = cvtfp32_fp8e4m3(o2);
      o128_3 = cvtfp32_fp8e4m3(o3);
    } else {
      o128_0 = cvtfp32_fp8e5m2(o0);
      o128_1 = cvtfp32_fp8e5m2(o1);
      o128_2 = cvtfp32_fp8e5m2(o2);
      o128_3 = cvtfp32_fp8e5m2(o3);
    }
    // ...and reassemble them into a single 512-bit vector.
    __m512i result = _mm512_setzero_si512();
    result = _mm512_inserti32x4(result, o128_0, 0);
    result = _mm512_inserti32x4(result, o128_1, 1);
    result = _mm512_inserti32x4(result, o128_2, 2);
    result = _mm512_inserti32x4(result, o128_3, 3);
    return result;
  }

 public:
  using value_type = uint8_t;
  using size_type = int;
  // Number of fp8 elements held by one vector: 512 bits / 8 bits.
  static constexpr size_type size() {
    return 64;
  }
  // Default construction leaves `values` uninitialized.
  Vectorizedf8() {}
  Vectorizedf8(__m512i v) : values(v) {}
  // Broadcast the raw bit pattern of `val` to all 64 lanes.
  Vectorizedf8(T val) {
    value_type uw = val.x;
    values = _mm512_set1_epi8(uw);
  }
  operator __m512i() const {
    return values;
  }
  // Per-element access is intentionally unavailable; use store() to
  // inspect lane values in memory.
  T& operator[](int idx) = delete;
  const T& operator[](int idx) const = delete;
  // Load `count` fp8 elements from (possibly unaligned) memory.
  // For count < size() (except the 16-element fast path) the remaining
  // lanes are zeroed by the masked load; after the cast in the
  // 16-element path the upper 384 bits are unspecified per the Intel
  // intrinsics documentation.
  static Vectorized<T> loadu(const void* ptr, int16_t count = size()) {
    if (count == size()) {
      return _mm512_loadu_si512(reinterpret_cast<const __m512i*>(ptr));
    } else if (count == 16) {
      // Fast path when loading exactly 16 elements (one 128-bit lane).
      __m128i input_128 =
          _mm_loadu_si128(reinterpret_cast<const __m128i*>(ptr));
      return _mm512_castsi128_si512(input_128);
    } else {
      // count < 64 here, so the shift below cannot overflow.
      __mmask64 mask = (1ULL << count) - 1;
      return _mm512_maskz_loadu_epi8(mask, ptr);
    }
  }
  // Store the first `count` fp8 elements to (possibly unaligned) memory.
  void store(void* ptr, int count = size()) const {
    if (count == size()) {
      _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values);
    } else if (count > 0) {
      if (count == 16) {
        // Fast path when storing exactly 16 elements (one 128-bit lane).
        _mm_storeu_si128(
            reinterpret_cast<__m128i*>(ptr), _mm512_castsi512_si128(values));
      } else {
        // 0 < count < 64 here, so the shift below cannot overflow.
        __mmask64 mask = (1ULL << count) - 1;
        _mm512_mask_storeu_epi8(ptr, mask, values);
      }
    }
  }
  // Clear the sign bit of every lane; both e4m3 and e5m2 keep the sign
  // in the top bit of the byte.
  Vectorized<T> abs() const {
    return _mm512_andnot_si512(_mm512_set1_epi8(0x80), values);
  }
  // Comparisons are computed in fp32 via binary_compare(); each lambda
  // turns the AVX-512 compare mask into a per-lane all-ones/zero fp32
  // mask that narrows back to the fp8 mask convention.
  Vectorized<T> inline operator==(const Vectorizedf8<T>& other) const {
    return binary_compare(other, [](__m512 x, __m512 y) {
      auto zero_vec = _mm512_set1_epi32(0);
      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_EQ_OQ);
      return _mm512_castsi512_ps(
          _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
    });
  }
  // Note: != uses an unordered predicate, so NaN != NaN yields true.
  Vectorized<T> inline operator!=(const Vectorizedf8<T>& other) const {
    return binary_compare(other, [](__m512 x, __m512 y) {
      auto zero_vec = _mm512_set1_epi32(0);
      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_NEQ_UQ);
      return _mm512_castsi512_ps(
          _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
    });
  }
  Vectorized<T> inline operator>(const Vectorizedf8<T>& other) const {
    return binary_compare(other, [](__m512 x, __m512 y) {
      auto zero_vec = _mm512_set1_epi32(0);
      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_GT_OQ);
      return _mm512_castsi512_ps(
          _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
    });
  }
  Vectorized<T> inline operator>=(const Vectorizedf8<T>& other) const {
    return binary_compare(other, [](__m512 x, __m512 y) {
      auto zero_vec = _mm512_set1_epi32(0);
      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_GE_OQ);
      return _mm512_castsi512_ps(
          _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
    });
  }
  Vectorized<T> inline operator<(const Vectorizedf8<T>& other) const {
    return binary_compare(other, [](__m512 x, __m512 y) {
      auto zero_vec = _mm512_set1_epi32(0);
      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_LT_OQ);
      return _mm512_castsi512_ps(
          _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
    });
  }
  Vectorized<T> inline operator<=(const Vectorizedf8<T>& other) const {
    return binary_compare(other, [](__m512 x, __m512 y) {
      auto zero_vec = _mm512_set1_epi32(0);
      auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_LE_OQ);
      return _mm512_castsi512_ps(
          _mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF));
    });
  }
};
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free