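left_shift Class — pytorch Architecture
Architecture documentation for `left_shift` in vec512_int.h from the pytorch codebase. Note that `left_shift` is not actually a class: it is the `bool` template parameter of the `shift_512_8` function template shown below, selecting a left shift (true) or a right shift (false).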
Entity Profile
Source Code
aten/src/ATen/cpu/vec/vec512/vec512_int.h lines 1845–2046
template <
    bool left_shift,
    typename T,
    typename std::enable_if_t<
        std::is_same_v<T, int8_t> || std::is_same_v<T, uint8_t>,
        int> = 0>
Vectorized<T> inline shift_512_8(
    const Vectorized<T>& a,
    const Vectorized<T>& b) {
  // Element-wise shift of the 64 8-bit lanes of `a` by the counts held in
  // the corresponding 8-bit lanes of `b`: a left shift when `left_shift` is
  // true, otherwise an arithmetic right shift for int8_t and a logical right
  // shift for uint8_t.
  //
  // There is no vector instruction for shifting int8_t/uint8_t, so it is
  // emulated with the 16-bit variable shifts: each value byte is placed in
  // the upper half of a 16-bit lane (zeros below it), the 16-bit lane is
  // shifted, and the upper byte of the result is the shifted 8-bit value.
  // Even- and odd-indexed bytes are processed separately and then merged.

  // 16-bit variable shift matching the requested direction and signedness.
  // These intrinsics treat each 16-bit count as unsigned and produce 0
  // (for srav: copies of the sign bit) when the count exceeds 15.
  // `left_shift` and `T` are compile-time, so `if constexpr` is used for
  // both decisions and the unused branches are discarded.
  auto shift_16 = [](__m512i values, __m512i counts) -> __m512i {
    if constexpr (left_shift) {
      return _mm512_sllv_epi16(values, counts);
    } else if constexpr (std::is_same_v<T, int8_t>) {
      return _mm512_srav_epi16(values, counts);
    } else {
      return _mm512_srlv_epi16(values, counts);
    }
  };

  // Control masks for shuffle operation, treating 512 bits as an array of
  // 8-bit elements grouped into pairs of neighboring elements. A mask named
  // "ctl_M_N" (M, N in [0, 1], M != N) moves the element with index M of
  // each input pair into index N of the output pair, and sets the element
  // with index M of the output pair to all 0s.
  //
  // _mm512_shuffle_epi8 zeroes a destination byte whose control byte has
  // its high bit set (0x80); otherwise it selects a byte from the same
  // 128-bit lane using only the low 4 bits of the control byte. So e.g. the
  // value 62 (placed in byte 63) selects byte 14 of the top lane, which is
  // byte 62 overall.
  const __m512i ctl_0_1 = _mm512_set_epi8(
      62, 0x80, 60, 0x80, 58, 0x80, 56, 0x80, 54, 0x80, 52, 0x80, 50, 0x80, 48, 0x80,
      46, 0x80, 44, 0x80, 42, 0x80, 40, 0x80, 38, 0x80, 36, 0x80, 34, 0x80, 32, 0x80,
      30, 0x80, 28, 0x80, 26, 0x80, 24, 0x80, 22, 0x80, 20, 0x80, 18, 0x80, 16, 0x80,
      14, 0x80, 12, 0x80, 10, 0x80, 8,  0x80, 6,  0x80, 4,  0x80, 2,  0x80, 0,  0x80);
  const __m512i ctl_1_0 = _mm512_set_epi8(
      0x80, 63, 0x80, 61, 0x80, 59, 0x80, 57, 0x80, 55, 0x80, 53, 0x80, 51, 0x80, 49,
      0x80, 47, 0x80, 45, 0x80, 43, 0x80, 41, 0x80, 39, 0x80, 37, 0x80, 35, 0x80, 33,
      0x80, 31, 0x80, 29, 0x80, 27, 0x80, 25, 0x80, 23, 0x80, 21, 0x80, 19, 0x80, 17,
      0x80, 15, 0x80, 13, 0x80, 11, 0x80, 9,  0x80, 7,  0x80, 5,  0x80, 3,  0x80, 1);

  // AND masks per 16-bit lane: keep_0 keeps the low byte (index 0 of each
  // pair) and keep_1 keeps the high byte (index 1 of each pair); the other
  // byte of the pair becomes 0.
  const __m512i keep_0 = _mm512_set1_epi16(0xFF);
  const __m512i keep_1 = _mm512_set1_epi16(0xFF00);

  // Elements with idx%2==0: move each value byte into the upper half of its
  // 16-bit lane (zeros added to the right) and zero-extend its count byte
  // to 16 bits. The count is not sign-extended, but a count byte that is
  // negative as int8_t becomes 128..255 here. That is greater than 15, so
  // the 16-bit shift gives the same result (0, or sign fill for srav) as it
  // would for a sign-extended count. The upper byte of each shifted lane is
  // then moved back to the even position and the odd position is zeroed.
  const __m512i a_even = _mm512_shuffle_epi8(a, ctl_0_1);
  const __m512i b_even = _mm512_and_si512(b, keep_0);
  const __m512i c_even =
      _mm512_shuffle_epi8(shift_16(a_even, b_even), ctl_1_0);

  // Elements with idx%2==1 already sit in the upper half of their 16-bit
  // lane, so masking off the even byte is enough. The count byte is moved
  // down into the low half, with the upper half zeroed. After shifting,
  // the result byte is already at the odd position; the low byte is cleared.
  const __m512i a_odd = _mm512_and_si512(a, keep_1);
  const __m512i b_odd = _mm512_shuffle_epi8(b, ctl_1_0);
  const __m512i c_odd = _mm512_and_si512(shift_16(a_odd, b_odd), keep_1);

  // The partial results have disjoint non-zero bytes; OR merges them.
  return _mm512_or_si512(c_even, c_odd);
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free