Home / Class / RequantizeAvx512 Class — pytorch Architecture

RequantizeAvx512 Class — pytorch Architecture

Architecture documentation for the RequantizeAvx512 class in vec512_qint.h from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/cpu/vec/vec512/vec512_qint.h lines 492–549

/*
 * Requantizes four vectors of 32-bit integers into one 512-bit vector of T
 * (T is int8_t or uint8_t; anything else is rejected at compile time).
 *
 * Each lane of every input vector is converted to float, multiplied by
 * `multiplier`, rounded back to a 32-bit integer and offset by `zp`. The four
 * results are packed pairwise to int16_t with signed saturation, handed to
 * pack_saturate_and_clamp<T> together with the limits of T, and finally
 * permuted so that the output is ordered as inp[0], inp[1], inp[2], inp[3].
 */
template <typename T>
__m512i RequantizeAvx512(
    const std::array<Vectorized<c10::qint32>, 4>& inp,
    __m512 multiplier,
    __m512i zp) {
  static_assert(
      std::is_same_v<T, int8_t> || std::is_same_v<T, uint8_t>,
      "Only int8_t/uint8_t are supported");

  constexpr T lowest = std::numeric_limits<T>::min();
  constexpr T highest = std::numeric_limits<T>::max();

  /* int32 -> float, scale, round back to int32, then add the zero point. */
  const auto scale_round_shift = [&](const Vectorized<c10::qint32>& v) {
    const __m512 scaled = _mm512_mul_ps(_mm512_cvtepi32_ps(v), multiplier);
    return _mm512_add_epi32(_mm512_cvtps_epi32(scaled), zp);
  };

  const __m512i q0 = scale_round_shift(inp[0]);
  const __m512i q1 = scale_round_shift(inp[1]);
  const __m512i q2 = scale_round_shift(inp[2]);
  const __m512i q3 = scale_round_shift(inp[3]);

  /* Narrow to int16_t with signed saturation. */
  const __m512i q01_i16 = _mm512_packs_epi32(q0, q1);
  const __m512i q23_i16 = _mm512_packs_epi32(q2, q3);

  const __m512i interleaved =
      pack_saturate_and_clamp<T>(q01_i16, q23_i16, lowest, highest);

  /*
   * At this point the 32-bit groups are interleaved across the inputs
   * (x = inp[0], y = inp[1], z = inp[2], w = inp[3]):
   *   x0-3 y0-3 z0-3 w0-3 x4-7 y4-7 z4-7 w4-7
   *   x8-11 y8-11 z8-11 w8-11 x12-15 y12-15 z12-15 w12-15
   * The indices below gather every fourth group so that all of x comes first,
   * followed by y, z and w. _mm512_set_epi32 takes the highest element first.
   */
  const __m512i gather_idx = _mm512_set_epi32(
      15, 11, 7, 3,
      14, 10, 6, 2,
      13, 9, 5, 1,
      12, 8, 4, 0);

  return _mm512_permutexvar_epi32(gather_idx, interleaved);
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free