Home / Class / left_shift Class — pytorch Architecture

left_shift Class — pytorch Architecture

Architecture documentation for `left_shift` — a `bool` template parameter (not a class) of the `shift_512_8` function template — in vec512_int.h from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/cpu/vec/vec512/vec512_int.h lines 1845–2046

/// Element-wise shift of 8-bit integer elements for AVX-512.
///
/// Shifts each 8-bit element of `a` by the amount held in the corresponding
/// 8-bit element of `b`. When `left_shift` is true a left shift is done;
/// otherwise a right shift is done, arithmetic for int8_t and logical for
/// uint8_t.
///
/// There is no 8-bit variable shift instruction, so this emulates one with
/// the 16-bit variable shifts (_mm512_sllv_epi16 / _mm512_srav_epi16 /
/// _mm512_srlv_epi16). Each 16-bit lane is treated as a pair of neighboring
/// bytes, and the two bytes of each pair are processed in two passes. In each
/// pass the byte being shifted sits in the high half of the 16-bit lane with
/// the low half zeroed, and its shift count is zero-extended to 16 bits.
/// After the 16-bit shift, the high half holds the 8-bit result.
///
/// Shift counts are therefore read as unsigned bytes (0..255). Any count >= 8,
/// including every negative int8_t count, yields 0 for the logical shifts and
/// the sign fill for the arithmetic right shift. This holds for two reasons:
/// bits shifted in from the zeroed low half are 0, and the 16-bit intrinsics
/// treat counts > 15 as shifting every bit out. Sign-extending the count
/// instead would give the same result, since it would also exceed 15.
template <
    bool left_shift,
    typename T,
    typename std::enable_if_t<
        std::is_same_v<T, int8_t> || std::is_same_v<T, uint8_t>,
        int> = 0>
Vectorized<T> inline shift_512_8(
    const Vectorized<T>& a,
    const Vectorized<T>& b) {
  // Byte-shuffle control masks.
  //
  // _mm512_shuffle_epi8 operates independently within each 128-bit lane.
  // A control byte with bit 7 set (0x80) writes 0. Otherwise its low 4 bits
  // select a byte from the same 128-bit lane. The indices below are written
  // as absolute byte positions, and only their low 4 bits matter.
  //
  // View the register as 32 pairs of neighboring bytes, one pair per 16-bit
  // lane (byte 0 = low half, byte 1 = high half). Then:
  //   ctl_0_1 moves byte 0 of each pair into byte 1 and zeroes byte 0;
  //   ctl_1_0 moves byte 1 of each pair into byte 0 and zeroes byte 1.
  // The first argument of _mm512_set_epi8 is the highest byte.
  // clang-format off
  const __m512i ctl_0_1 = _mm512_set_epi8(
      62, 0x80, 60, 0x80, 58, 0x80, 56, 0x80,
      54, 0x80, 52, 0x80, 50, 0x80, 48, 0x80,
      46, 0x80, 44, 0x80, 42, 0x80, 40, 0x80,
      38, 0x80, 36, 0x80, 34, 0x80, 32, 0x80,
      30, 0x80, 28, 0x80, 26, 0x80, 24, 0x80,
      22, 0x80, 20, 0x80, 18, 0x80, 16, 0x80,
      14, 0x80, 12, 0x80, 10, 0x80, 8, 0x80,
      6, 0x80, 4, 0x80, 2, 0x80, 0, 0x80);
  const __m512i ctl_1_0 = _mm512_set_epi8(
      0x80, 63, 0x80, 61, 0x80, 59, 0x80, 57,
      0x80, 55, 0x80, 53, 0x80, 51, 0x80, 49,
      0x80, 47, 0x80, 45, 0x80, 43, 0x80, 41,
      0x80, 39, 0x80, 37, 0x80, 35, 0x80, 33,
      0x80, 31, 0x80, 29, 0x80, 27, 0x80, 25,
      0x80, 23, 0x80, 21, 0x80, 19, 0x80, 17,
      0x80, 15, 0x80, 13, 0x80, 11, 0x80, 9,
      0x80, 7, 0x80, 5, 0x80, 3, 0x80, 1);
  // clang-format on

  // Per-pair AND masks: keep_0 retains byte 0 (the low half) of each 16-bit
  // lane; keep_1 retains byte 1 (the high half). The other byte is cleared.
  const __m512i keep_0 = _mm512_set1_epi16(0xFF);
  const __m512i keep_1 = _mm512_set1_epi16(0xFF00);

  // 16-bit variable shift. The instruction is selected at compile time from
  // the template arguments, so only one branch is ever instantiated.
  auto shift_epi16 = [](__m512i x, __m512i count) -> __m512i {
    if constexpr (left_shift) {
      return _mm512_sllv_epi16(x, count);
    } else if constexpr (std::is_same_v<T, int8_t>) {
      // Arithmetic shift: the operand's sign bit is the lane's sign bit.
      return _mm512_srav_epi16(x, count);
    } else {
      return _mm512_srlv_epi16(x, count);
    }
  };

  // Pass 1: even-indexed bytes (byte 0 of each pair).
  // - Move the operand into the high half, with the low half zeroed.
  // - Zero-extend its count by masking off the high half of `b`.
  // - Shift, then move the high-half result back into byte 0 and clear byte 1.
  const __m512i a0 = _mm512_shuffle_epi8(a, ctl_0_1);
  const __m512i b0 = _mm512_and_si512(b, keep_0);
  const __m512i c0 = _mm512_shuffle_epi8(shift_epi16(a0, b0), ctl_1_0);

  // Pass 2: odd-indexed bytes (byte 1 of each pair).
  // - The operand is already in the high half, so only clear the low half.
  // - Move the count into the low half, which zero-extends it.
  // - Shift, then keep only the high half of the result.
  const __m512i a1 = _mm512_and_si512(a, keep_1);
  const __m512i b1 = _mm512_shuffle_epi8(b, ctl_1_0);
  const __m512i c1 = _mm512_and_si512(shift_epi16(a1, b1), keep_1);

  // The two passes populate disjoint bytes, so a bitwise OR merges them.
  return _mm512_or_si512(c0, c1);
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free