test_packed Method — pytorch Architecture
Architecture documentation for the test_packed member function in gemm-block-sparse-microkernel-tester.h from the pytorch codebase.
Entity Profile
Relationship Graph
Source Code
aten/src/ATen/native/quantized/cpu/qnnpack/test/gemm-block-sparse-microkernel-tester.h lines 358–493
// Tests the packed-A variant of the block-sparse quantized GEMM microkernel:
//  1. Fills A with random uint8 data and B with random block-sparse uint8
//     weights (per-channel zero points).
//  2. Computes a scalar dequantized reference result into `acc`.
//  3. Packs A with `packa`, converts B to block-CSR, runs `qgemm`, and
//     checks the optimized output against the reference to 0.1% relative
//     tolerance.
//
// packa: ukernel that packs the uint8 A matrix into mr x kr tiles.
// qgemm: block-sparse dynamic-quantization GEMM ukernel under test.
template <typename SPARSE_INDICES_DTYPE, typename GEMM_UKERNEL_DTYPE>
void test_packed(
    pytorch_q8gemm_sparse_packA_ukernel_function packa,
    GEMM_UKERNEL_DTYPE qgemm) const {
  // A single microkernel invocation handles at most mr x nr output elements.
  ASSERT_LE(m(), mr());
  ASSERT_LE(n(), nr());

  std::random_device randomDevice;
  auto rng = std::mt19937(randomDevice());
  auto s32rng =
      std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), rng);
  // NOTE: std::uniform_int_distribution<uint8_t> is undefined behavior per
  // the C++ standard (char-sized types are not permitted IntTypes; MSVC
  // rejects it outright). Draw uint32_t in [0, 255] instead and rely on the
  // narrowing conversion when the values are stored into uint8_t vectors.
  auto u8rng =
      std::bind(std::uniform_int_distribution<uint32_t>(0, 255), rng);

  // A carries aStride() padding per row plus 8 trailing bytes of slack for
  // kernels that read past the end.
  std::vector<uint8_t> a((m() - 1) * aStride() + k() + 8);
  std::vector<uint8_t> b(n() * k());
  std::vector<float, AlignedAllocator<float, 32>> bias(std::max<size_t>(8, n()));
  std::vector<float> c((m() - 1) * cStride() + n());
  std::vector<float> acc(m() * n());

  auto m_blocks = (m() + mr() - 1) / mr();
  // While colBlockSize() is what kr is, we reuse 8x4/4x4 packing kernels
  // and thus a_packed has to be allocated accordingly.
  const uint32_t kr_value = 4;
  auto k_blocks = (k() + kr_value - 1) / kr_value;
  std::vector<uint8_t> a_packed((m_blocks * k_blocks * mr() * kr_value) + 8, 0);

  const uint8_t* aPtr = a.data();

  for (size_t iteration = 0; iteration < iterations(); iteration++) {
    std::generate(a.begin(), a.end(), std::ref(u8rng));
    std::generate(bias.begin(), bias.end(), std::ref(s32rng));
    std::fill(c.begin(), c.end(), 0.0f);

    // Per-output-channel zero points, padded so vector kernels can over-read.
    size_t num_zero_points_padded = n() + 8;
    std::vector<uint8_t> kernel_zero_points
        (num_zero_points_padded, bZeroPoint());

    uint8_t max_elem, min_elem;
    // This loop to ensure the assert_ne on b mat does not fire: regenerate B
    // until it is not constant.
    do {
      std::generate(b.begin(), b.end(), std::ref(u8rng));
      fillBlockSparseWeights(
          b.data(),
          n(),
          k(),
          rowBlockSize(),
          colBlockSize(),
          sparsity(),
          kernel_zero_points.data());
      max_elem = *std::max_element(b.cbegin(), b.cend());
      min_elem = *std::min_element(b.cbegin(), b.cend());
    } while (max_elem == min_elem);

    std::unique_ptr<qnnpack::BCSRMatrix> bcsr_matrix =
        qnnpack::generateBlockCSRMatrix<SPARSE_INDICES_DTYPE>(
            b.data(),
            n(),
            k(),
            rowBlockSize(),
            colBlockSize(),
            kernel_zero_points.data());

    // A constant A or B matrix would make the test trivially pass.
    ASSERT_NE(
        *std::max_element(a.cbegin(), a.cend()),
        *std::min_element(a.cbegin(), a.cend()));
    ASSERT_NE(
        *std::max_element(b.cbegin(), b.cend()),
        *std::min_element(b.cbegin(), b.cend()));

    auto f32rng =
        std::bind(std::uniform_real_distribution<float>(1, 5), rng);
    std::vector<float> dequantization_scales(num_zero_points_padded, 1.f);
    std::generate(
        dequantization_scales.begin(),
        dequantization_scales.end(),
        std::ref(f32rng));

    /* Compute 32-bit results and output quantization arguments */
    std::fill(acc.begin(), acc.end(), 0);
    for (size_t mIndex = 0; mIndex < m(); mIndex++) {
      for (size_t nIndex = 0; nIndex < n(); nIndex++) {
        for (size_t kIndex = 0; kIndex < k(); kIndex++) {
          ASSERT_LT(mIndex * n() + nIndex, acc.size());
          // Check the index actually used to read A below (rows are
          // aStride() apart, not k() apart).
          ASSERT_LT(mIndex * aStride() + kIndex, a.size());
          acc[mIndex * n() + nIndex] +=
              (int32_t(aPtr[mIndex * aStride() + kIndex]) -
               int32_t(aZeroPoint())) *
              (int32_t(b[nIndex * k() + kIndex]) - int32_t(kernel_zero_points[nIndex]));
        }
        // Dequantize with the per-channel scale and add the float bias.
        acc[mIndex * n() + nIndex] =
            acc[mIndex * n() + nIndex] *
                dequantization_scales[nIndex] +
            bias[nIndex];
      }
    }

    const struct pytorch_qnnp_conv_dynamic_quantization_params quantizationParams{
        aZeroPoint(),
        kernel_zero_points.data(),
        dequantization_scales.data(),
    };

    // Pack A into mr x kr tiles, then run the sparse GEMM ukernel.
    packa(
        m(),
        k(),
        aPtr,
        aStride() * sizeof(uint8_t),
        a_packed.data());
    qgemm(
        m(),
        n(),
        a_packed.data(),
        bcsr_matrix->values.data(),
        static_cast<const SPARSE_INDICES_DTYPE*>(
            bcsr_matrix->row_values_data_ptr()),
        static_cast<const SPARSE_INDICES_DTYPE*>(
            bcsr_matrix->col_indices_data_ptr()),
        bias.data(),
        c.data(),
        cStride(),
        0,
        &quantizationParams);

    // Compare optimized output against the scalar reference at 0.1%
    // relative tolerance.
    for (size_t mIndex = 0; mIndex < m(); mIndex++) {
      for (size_t nIndex = 0; nIndex < n(); nIndex++) {
        ASSERT_NEAR(
            c[mIndex * cStride() + nIndex],
            acc[mIndex * n() + nIndex],
            std::abs(acc[mIndex * n() + nIndex]) * 1.0e-3f)
            << "at " << mIndex << ", " << nIndex
            << ": reference = " << acc[mIndex * n() + nIndex]
            << ", optimized = " << c[mIndex * cStride() + nIndex]
            << ", Mr x Nr = " << mr() << " x " << nr()
            << ", M x N x K = " << m() << " x " << n() << " x " << k();
      }
    }
  }
}
Domain
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free