Vectorized8 Class — PyTorch Architecture
Architecture documentation for the Vectorized8 class in vec512_int.h from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/cpu/vec/vec512/vec512_int.h lines 728–1138
class Vectorized8 : public Vectorizedi {
  // T is supplied by the `template <typename T>` header above this class;
  // only the two 8-bit integer element types are supported.
  static_assert(
      std::is_same_v<T, int8_t> || std::is_same_v<T, uint8_t>,
      "Only int8_t/uint8_t are supported");
 protected:
  // All-zero 512-bit register, used as the merge source of masked loads so
  // that lanes at or beyond `count` come out zero-initialized.
  static constexpr __m512i zero_vector{0, 0, 0, 0, 0, 0, 0, 0};
  static const Vectorized<T> ones;
 public:
  using value_type = T;
  // Number of 8-bit lanes in one 512-bit register.
  static constexpr int size() {
    return 64;
  }
  using Vectorizedi::Vectorizedi;
  Vectorized8() {}
  // Broadcast a single scalar into all 64 lanes.
  Vectorized8(T v) {
    values = _mm512_set1_epi8(v);
  }
  // Lane-wise constructor: val1 lands in lane 0 (the lowest byte).
  // _mm512_set_epi8 takes its arguments highest-lane first, hence the
  // reversed argument order in the call below.
  Vectorized8(
      T val1, T val2, T val3, T val4, T val5, T val6, T val7, T val8,
      T val9, T val10, T val11, T val12, T val13, T val14, T val15, T val16,
      T val17, T val18, T val19, T val20, T val21, T val22, T val23, T val24,
      T val25, T val26, T val27, T val28, T val29, T val30, T val31, T val32,
      T val33, T val34, T val35, T val36, T val37, T val38, T val39, T val40,
      T val41, T val42, T val43, T val44, T val45, T val46, T val47, T val48,
      T val49, T val50, T val51, T val52, T val53, T val54, T val55, T val56,
      T val57, T val58, T val59, T val60, T val61, T val62, T val63, T val64) {
    values = _mm512_set_epi8(
        val64, val63, val62, val61, val60, val59, val58, val57,
        val56, val55, val54, val53, val52, val51, val50, val49,
        val48, val47, val46, val45, val44, val43, val42, val41,
        val40, val39, val38, val37, val36, val35, val34, val33,
        val32, val31, val30, val29, val28, val27, val26, val25,
        val24, val23, val22, val21, val20, val19, val18, val17,
        val16, val15, val14, val13, val12, val11, val10, val9,
        val8, val7, val6, val5, val4, val3, val2, val1);
  }
  // Per-lane select: lane i comes from `b` when bit i of `mask` is set,
  // otherwise from `a`.
  template <int64_t mask>
  static Vectorized<T> blend(Vectorized<T> a, Vectorized<T> b) {
    return _mm512_mask_blend_epi8(mask, a.values, b.values);
  }
  // Returns {base, base + step, base + 2*step, ...} across all 64 lanes.
  template <typename step_t>
  static Vectorized<T> arange(
      T base = 0,
      step_t step = static_cast<step_t>(1)) {
    return Vectorized<T>(
        base, base + step, base + 2 * step, base + 3 * step,
        base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step,
        base + 8 * step, base + 9 * step, base + 10 * step, base + 11 * step,
        base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step,
        base + 16 * step, base + 17 * step, base + 18 * step, base + 19 * step,
        base + 20 * step, base + 21 * step, base + 22 * step, base + 23 * step,
        base + 24 * step, base + 25 * step, base + 26 * step, base + 27 * step,
        base + 28 * step, base + 29 * step, base + 30 * step, base + 31 * step,
        base + 32 * step, base + 33 * step, base + 34 * step, base + 35 * step,
        base + 36 * step, base + 37 * step, base + 38 * step, base + 39 * step,
        base + 40 * step, base + 41 * step, base + 42 * step, base + 43 * step,
        base + 44 * step, base + 45 * step, base + 46 * step, base + 47 * step,
        base + 48 * step, base + 49 * step, base + 50 * step, base + 51 * step,
        base + 52 * step, base + 53 * step, base + 54 * step, base + 55 * step,
        base + 56 * step, base + 57 * step, base + 58 * step, base + 59 * step,
        base + 60 * step, base + 61 * step, base + 62 * step, base + 63 * step);
  }
  // Returns a vector whose first `count` lanes come from `b` and whose
  // remaining lanes come from `a`. count == 0 yields `a` unchanged;
  // count >= size() (and, for int8_t, any negative count) yields `b`,
  // matching the semantics of the blend-per-count dispatch this replaces.
  static Vectorized<T> set(Vectorized<T> a, Vectorized<T> b, T count = size()) {
    if (count == 0) {
      return a;
    }
    if (count > 0 && count < size()) {
      // _mm512_mask_blend_epi8 accepts a runtime mask, so a computed
      // low-`count`-bits mask replaces a 64-case switch over compile-time
      // blend masks. count is in [1, 63] here, so the shift is well-defined.
      __mmask64 mask = (1ULL << count) - 1;
      return _mm512_mask_blend_epi8(mask, a.values, b.values);
    }
    return b;
  }
  // Full-width unaligned load of 64 lanes.
  static Vectorized<T> loadu(const void* ptr) {
    return _mm512_loadu_si512(reinterpret_cast<const __m512i*>(ptr));
  }
  static Vectorized<T> loadu_one_fourth(const void* ptr) {
    // Fast path if only load element number of 16.
    // Note: We didn't merge it as fast path of loadu(const void* ptr, T count),
    // Because loadu(const void* ptr, T count) requires zero initialization for
    // upper 384 bits. However, by using _mm512_castsi128_si512, the upper 384
    // bits of the result are undefined.
    // TODO<leslie> We can use _mm512_zextsi128_si512 in the future,
    // since gcc 9.3 doesn't support it now.
    __m128i input_128 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(ptr));
    return _mm512_castsi128_si512(input_128);
  }
  // Load the first `count` lanes from `ptr`; lanes >= count are zeroed
  // (except for the count == 16 fast path, which leaves them undefined —
  // see loadu_one_fourth).
  static Vectorized<T> loadu(const void* ptr, T count) {
    if (count == size()) {
      return _mm512_loadu_si512(reinterpret_cast<const __m512i*>(ptr));
    } else if (count == 16) {
      // Fast path if only load element number of 16
      return loadu_one_fourth(ptr);
    } else {
      __mmask64 mask = (1ULL << count) - 1;
      // Merge with zero_vector so the masked-off tail lanes are zeroed, as
      // this overload's contract requires; merging with a nonzero source
      // would leak that source's bytes into the tail.
      return _mm512_mask_loadu_epi8(zero_vector, mask, ptr);
    }
  }
  // Store the first `count` lanes to `ptr`; count <= 0 stores nothing.
  void store(void* ptr, int count = size()) const {
    if (count == size()) {
      // ptr need not to be aligned here. See
      // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm512-storeu-si512.html
      _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values);
    } else if (count > 0) {
      if (count == 16) {
        // Fast path if only store element number of 16
        _mm_storeu_si128(
            reinterpret_cast<__m128i*>(ptr), _mm512_castsi512_si128(values));
      } else {
        __mmask64 mask = (1ULL << count) - 1;
        _mm512_mask_storeu_epi8(ptr, mask, values);
      }
    }
  }
  // Scalar indexing is intentionally disabled; use store() to a buffer.
  const T& operator[](int idx) const = delete;
  T& operator[](int idx) = delete;
  // Complex-number interface stubs for integer vectors: the real part is
  // the value itself and the imaginary part is zero.
  Vectorized<T> real() const {
    return *this;
  }
  Vectorized<T> imag() const {
    return _mm512_set1_epi8(0);
  }
  Vectorized<T> conj() const {
    return *this;
  }
};
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free