Home / Class / binary_compare Class — pytorch Architecture

binary_compare Class — pytorch Architecture

Architecture documentation for binary_compare, a member function template defined in vec512_float8.h in the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/cpu/vec/vec512/vec512_float8.h lines 269–318

  // Lane-wise binary operation on two fp8 vectors, evaluated in fp32.
  //
  // `*this` and `b` each carry their fp8 payload in a 512-bit `values`
  // register made of four 128-bit lanes. Each 128-bit lane is widened to a
  // __m512 of fp32 values, `op` is applied to the matching pair of widened
  // lanes, and each fp32 result is narrowed back to fp8 using the same format
  // as T (e4m3fn when T is c10::Float8_e4m3fn, otherwise e5m2).
  //
  // NOTE(review): the name suggests `op` is a comparison whose fp32 output
  // encodes a per-element mask. That contract is defined by the callers, not
  // here, so confirm it there.
  //
  // @param b   right-hand operand; only its `values` member is read.
  // @param op  callable invoked as op(__m512, __m512); its result must be
  //            convertible to __m512. It is called once per lane, lanes 0..3
  //            in order.
  // @return    the four narrowed lanes packed back into a Vectorized<T>.
  template <typename Op, typename VectorizedType>
  Vectorized<T> inline binary_compare(const VectorizedType& b, Op op) const {
    // fp8 -> fp32 for one 128-bit lane, using the decoder that matches T.
    auto widen = [](const __m128i& lane) {
      __m512 wide;
      if constexpr (std::is_same_v<T, c10::Float8_e4m3fn>) {
        cvtfp8e4m3_fp32(lane, wide);
      } else {
        cvtfp8e5m2_fp32(lane, wide);
      }
      return wide;
    };
    // fp32 -> fp8 for one widened result, using the encoder that matches T.
    auto narrow = [](const __m512& wide) -> __m128i {
      if constexpr (std::is_same_v<T, c10::Float8_e4m3fn>) {
        return cvtfp32_fp8e4m3(wide);
      } else {
        return cvtfp32_fp8e5m2(wide);
      }
    };

    // The lane selector of the extract/insert intrinsics must be a
    // compile-time immediate, so the four lanes are unrolled by hand.
    __m512 lhs_0 = widen(_mm512_extracti32x4_epi32(values, 0));
    __m512 rhs_0 = widen(_mm512_extracti32x4_epi32(b.values, 0));
    __m512 lhs_1 = widen(_mm512_extracti32x4_epi32(values, 1));
    __m512 rhs_1 = widen(_mm512_extracti32x4_epi32(b.values, 1));
    __m512 lhs_2 = widen(_mm512_extracti32x4_epi32(values, 2));
    __m512 rhs_2 = widen(_mm512_extracti32x4_epi32(b.values, 2));
    __m512 lhs_3 = widen(_mm512_extracti32x4_epi32(values, 3));
    __m512 rhs_3 = widen(_mm512_extracti32x4_epi32(b.values, 3));

    const __m512 res_0 = op(lhs_0, rhs_0);
    const __m512 res_1 = op(lhs_1, rhs_1);
    const __m512 res_2 = op(lhs_2, rhs_2);
    const __m512 res_3 = op(lhs_3, rhs_3);

    // Reassemble the 512-bit result. All four lanes are overwritten, so the
    // zero seed never reaches the caller.
    __m512i packed = _mm512_setzero_si512();
    packed = _mm512_inserti32x4(packed, narrow(res_0), 0);
    packed = _mm512_inserti32x4(packed, narrow(res_1), 1);
    packed = _mm512_inserti32x4(packed, narrow(res_2), 2);
    packed = _mm512_inserti32x4(packed, narrow(res_3), 3);

    return packed;
  }

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free