Home / Class/ void Class — pytorch Architecture

void Class — pytorch Architecture

Architecture documentation for the QuantizeAvx512 function template (return type void) in vec512_qint.h from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/cpu/vec/vec512/vec512_qint.h lines 181–333

// Quantizes `len` floats read from `src` and writes the results to `dst`.
//
// For each element the computation is
//   q = round(src[i] * inverse_scale) + zero_point
// followed by clamping q to [numeric_limits<T>::min(), numeric_limits<T>::max()].
//
// The work is split into three phases:
//   1. a main loop handling 4 * VLEN (64) elements per iteration,
//   2. a loop handling VLEN (16) elements per iteration,
//   3. a scalar loop for whatever is left.
//
// NOTE(review): both vector phases store exactly one byte per element
// (64 bytes for 64 elements, 16 bytes for 16 elements), so this appears to
// assume sizeof(T) == 1 -- confirm against the callers.
//
// `src` and `dst` are accessed with unaligned loads/stores, so no particular
// alignment is required from the caller.
template <typename T>
__FORCE_INLINE void QuantizeAvx512(
    const float* src,
    T* dst,
    int len,
    float inverse_scale,
    int64_t zero_point) {
  constexpr int VLEN = 16;
  constexpr auto min_val = std::numeric_limits<T>::min();
  constexpr auto max_val = std::numeric_limits<T>::max();
  const __m512i min_v = _mm512_set1_epi32(min_val);
  const __m512i max_v = _mm512_set1_epi32(max_val);
  // This is the largest int32 value < int32_max exactly representable in float
  constexpr int32_t int32_float_max_val =
      std::numeric_limits<int32_t>::max() - 127;
  int i = 0;

  // Loop-invariant broadcast vectors, built once instead of per iteration.
  const __m512 inverse_scale_v = _mm512_set1_ps(inverse_scale);
  const __m512 float_max_v = _mm512_set1_ps(int32_float_max_val);
  // The zero point is deliberately narrowed to 32 bits: all vector arithmetic
  // below is done on int32 lanes.
  const __m512i zero_point_v =
      _mm512_set1_epi32(static_cast<int>(zero_point));

  // Byte shuffle used by the 16-element loop: within every 128-bit lane, the
  // low byte of each of the four int32 values is gathered into the lane's
  // first four bytes; the remaining bytes are zeroed (0xff selects zero).
  // clang-format off
  static const __m512i shuffle_mask_v = _mm512_set_epi8(
      0xff, 0xff, 0xff, 0xff,
      0xff, 0xff, 0xff, 0xff,
      0xff, 0xff, 0xff, 0xff,
      0x0c, 0x08, 0x04, 0x00,
      0xff, 0xff, 0xff, 0xff,
      0xff, 0xff, 0xff, 0xff,
      0xff, 0xff, 0xff, 0xff,
      0x0c, 0x08, 0x04, 0x00,
      0xff, 0xff, 0xff, 0xff,
      0xff, 0xff, 0xff, 0xff,
      0xff, 0xff, 0xff, 0xff,
      0x0c, 0x08, 0x04, 0x00,
      0xff, 0xff, 0xff, 0xff,
      0xff, 0xff, 0xff, 0xff,
      0xff, 0xff, 0xff, 0xff,
      0x0c, 0x08, 0x04, 0x00);
  // clang-format on

  // 32-bit lane permutation applied after packing in the main loop. Arguments
  // of _mm512_set_epi32 are listed from the highest lane to the lowest, so
  // output lanes 0..15 read input lanes 0,4,8,12, 1,5,9,13, 2,6,10,14,
  // 3,7,11,15. This undoes the per-128-bit-lane interleaving left behind by
  // the pack instructions so that bytes end up in source order.
  const __m512i permute_mask_v = _mm512_set_epi32(
      0x0f,
      0x0b,
      0x07,
      0x03,
      0x0e,
      0x0a,
      0x06,
      0x02,
      0x0d,
      0x09,
      0x05,
      0x01,
      0x0c,
      0x08,
      0x04,
      0x00);

  // 32-bit lane permutation for the 16-element loop: output lanes 0..3 read
  // input lanes 0,4,8,12 (the first dword of every 128-bit lane, which holds
  // the four bytes gathered by shuffle_mask_v). Only the low 128 bits of the
  // result are stored, so the upper lanes are don't-care.
  const __m512i permute_mask_l8_v = _mm512_set_epi32(
      0x00,
      0x00,
      0x00,
      0x00,
      0x00,
      0x00,
      0x00,
      0x00,
      0x00,
      0x00,
      0x00,
      0x00,
      0x0c,
      0x08,
      0x04,
      0x00);

  // Phase 1: 4 * VLEN elements per iteration.
  const int len_aligned = len / (VLEN * 4) * (VLEN * 4);
  for (; i < len_aligned; i += 4 * VLEN) {
    // Unaligned loads are used on purpose: nothing in this function's
    // signature guarantees that `src` is 64-byte aligned, and the stores to
    // `dst` already use the unaligned form.
    // x
    __m512 x_vals = _mm512_loadu_ps(src + i);
    __m512 x_transformed_v = _mm512_mul_ps(x_vals, inverse_scale_v);
    // If the floating point value is greater than int32_max,
    // _mm512_cvtps_epi32 converts them to -ve.
    // Clip at int32_float_max_val to avoid this.
    x_transformed_v = _mm512_min_ps(x_transformed_v, float_max_v);
    // y
    __m512 y_vals = _mm512_loadu_ps(src + i + VLEN);
    __m512 y_transformed_v = _mm512_mul_ps(y_vals, inverse_scale_v);
    y_transformed_v = _mm512_min_ps(y_transformed_v, float_max_v);
    // z
    __m512 z_vals = _mm512_loadu_ps(src + i + 2 * VLEN);
    __m512 z_transformed_v = _mm512_mul_ps(z_vals, inverse_scale_v);
    z_transformed_v = _mm512_min_ps(z_transformed_v, float_max_v);
    // w
    __m512 w_vals = _mm512_loadu_ps(src + i + 3 * VLEN);
    __m512 w_transformed_v = _mm512_mul_ps(w_vals, inverse_scale_v);
    w_transformed_v = _mm512_min_ps(w_transformed_v, float_max_v);

    // float -> int32 conversion
    __m512i x_rounded_v = _mm512_cvtps_epi32(x_transformed_v);
    __m512i y_rounded_v = _mm512_cvtps_epi32(y_transformed_v);
    __m512i z_rounded_v = _mm512_cvtps_epi32(z_transformed_v);
    __m512i w_rounded_v = _mm512_cvtps_epi32(w_transformed_v);

    // add zero point
    x_rounded_v = _mm512_add_epi32(x_rounded_v, zero_point_v);
    y_rounded_v = _mm512_add_epi32(y_rounded_v, zero_point_v);
    z_rounded_v = _mm512_add_epi32(z_rounded_v, zero_point_v);
    w_rounded_v = _mm512_add_epi32(w_rounded_v, zero_point_v);

    // Narrow int32 -> int16 with signed saturation, then hand both halves to
    // pack_saturate_and_clamp (defined elsewhere; presumably narrows to T and
    // clamps to [min_val, max_val] -- verify against its definition).
    __m512i xy_packed_v = _mm512_packs_epi32(x_rounded_v, y_rounded_v);
    __m512i zw_packed_v = _mm512_packs_epi32(z_rounded_v, w_rounded_v);
    __m512i xyzw_clamped_v =
        pack_saturate_and_clamp<T>(xy_packed_v, zw_packed_v, min_val, max_val);

    // Restore source element order and write 64 bytes.
    xyzw_clamped_v = _mm512_permutexvar_epi32(permute_mask_v, xyzw_clamped_v);
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + i), xyzw_clamped_v);
  }

  // Phase 2: additional single-vector (VLEN = 16 floats) AVX512 version to
  // take advantage when len is smaller.
  // Based on fbgemm::QuantizeAvx2 (https://github.com/pytorch/FBGEMM)
  const int len_vec = len / VLEN * VLEN;
  for (; i < len_vec; i += VLEN) {
    __m512 x_vals = _mm512_loadu_ps(src + i);
    __m512 x_transformed_v = _mm512_mul_ps(x_vals, inverse_scale_v);
    x_transformed_v = _mm512_min_ps(x_transformed_v, float_max_v);
    __m512i x_rounded_v = _mm512_cvtps_epi32(x_transformed_v);
    x_rounded_v = _mm512_add_epi32(x_rounded_v, zero_point_v);
    __m512i x_clipped_v =
        _mm512_max_epi32(min_v, _mm512_min_epi32(max_v, x_rounded_v));

    // Keep only the low byte of every int32, compact the 16 bytes into the
    // low 128 bits and store them.
    x_clipped_v = _mm512_shuffle_epi8(x_clipped_v, shuffle_mask_v);
    x_clipped_v = _mm512_permutexvar_epi32(permute_mask_l8_v, x_clipped_v);
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>(dst + i),
        _mm512_castsi512_si128(x_clipped_v));
  }

  // Phase 3: scalar tail for the remaining (< VLEN) elements.
  for (; i < len; ++i) {
    float transformed = src[i] * inverse_scale;

    // Not exactly the same behavior as the vectorized code.
    // The vectorized code above always rounds to even in halfway cases
    // (https://software.intel.com/en-us/node/523819), but std::nearbyint
    // does the same only when the current rounding mode is FE_TONEAREST.
    // However, in practice, this should not be a problem because most cases
    // use the default rounding mode FE_TONEAREST.
    // Note that we cannot implement the same behavior as the vectorized code
    // using std::round because it does rounding away from zero in halfway
    // cases.
    transformed = zero_point + std::nearbyint(transformed);
    float clipped =
        std::min(std::max(transformed, float(min_val)), float(max_val));
    dst[i] = static_cast<T>(clipped);
  }
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free