Home / Class / transpose_mxn Class — pytorch Architecture

transpose_mxn Class — pytorch Architecture

Architecture documentation for the transpose_mxn<BFloat16, 16, 16> function template specialization (indexed here as a "class") in vec512_bfloat16.h from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h lines 1187–1233

// Specialization of transpose_mxn for a 16x16 tile of BFloat16.
//
// src / ld_src: source tile; row r starts at src + r * ld_src (in elements).
// dst / ld_dst: destination tile; row r starts at dst + r * ld_dst (in
//               elements).
//
// Every row is moved with a single unaligned 256-bit load or store. The
// actual shuffling is delegated to _transpose_mxn_half_16_16, which is defined
// elsewhere in this file (presumably a 16x16 transpose of 16-bit lanes --
// confirm against its definition).
template <>
inline void transpose_mxn<BFloat16, 16, 16>(
    const BFloat16* src,
    int64_t ld_src,
    BFloat16* dst,
    int64_t ld_dst) {
  // Stage 1: pull the 16 source rows into registers, one register per row.
  //   row  0 (a): a0 a1 a2 ... a15
  //   row  1 (b): b0 b1 b2 ... b15
  //   ...
  //   row 15 (p): p0 p1 p2 ... p15
  __m256i rows[16];
#ifndef __msvc_cl__
#pragma unroll(16)
#endif
  for (int r = 0; r < 16; ++r) {
    const BFloat16* row_ptr = src + r * ld_src;
    rows[r] = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(row_ptr));
  }

  // Stage 2: the helper consumes the 16 x 256-bit inputs and fills
  // 8 x 512-bit outputs.
  __m512i packed[8];
  _transpose_mxn_half_16_16(rows, packed);

  // Stage 3: each 512-bit result carries two destination rows. The low
  // 256 bits go to row 2*p and the high 256 bits go to row 2*p + 1.
#ifndef __msvc_cl__
#pragma unroll(8)
#endif
  for (int p = 0; p < 8; ++p) {
    BFloat16* even_row = dst + (p * 2) * ld_dst;
    const __m256i lower_half = _mm512_extracti32x8_epi32(packed[p], 0x0);
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(even_row), lower_half);

    BFloat16* odd_row = dst + (p * 2 + 1) * ld_dst;
    const __m256i upper_half = _mm512_extracti32x8_epi32(packed[p], 0x01);
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(odd_row), upper_half);
  }
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free