DqMma Partial Specialization (pre-SM80, pipelined) — pytorch Architecture
Architecture documentation for the pre-SM80 partial specialization of the DqMma template (kAlignmentA is one of its int template parameters) in default_dq_mma_pipelined.h from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cuda/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h lines 19–156
// Partial specialization of DqMma for architectures *older than* SM80 (see the
// enable_if constraint below) using the 2-stage software-pipelined mainloop
// (stage count fixed to 2 in the specialization arguments). It assembles the
// threadblock-scoped dequantizing GEMM: a half/bf16 A operand multiplied by a
// quantized integer B operand (uint8 / uint4) that is dequantized with a
// per-tile scale operand before the tensor-core math.
template<
    /// Type for element A
    typename ElementA,
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Type for element B
    typename ElementB,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for the input scale
    typename ElementScale,
    /// Layout for the scale operand
    typename LayoutScale,
    /// Access granularity of Scales in unit of elements
    int kAlignmentScale,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Operator class tag
    typename OperatorClass,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator>
struct DqMma<ElementA,
             LayoutA,
             kAlignmentA,
             ElementB,
             LayoutB,
             kAlignmentB,
             ElementScale,
             LayoutScale,
             kAlignmentScale,
             ElementAccumulator,
             layout::RowMajor,
             OperatorClass,
             ArchTag,
             ThreadblockShape,
             WarpShape,
             InstructionShape,
             2,  // Stages: this specialization covers only the 2-stage pipelined mainloop.
             Operator,
             SharedMemoryClearOption::kNone,
             // SFINAE: selected only for pre-SM80 architectures; SM80+ uses a
             // different (multistage) specialization elsewhere in the file set.
             typename platform::enable_if<(ArchTag::kMinComputeCapability < 80)>::type> {

    // Supported operand types: A must be a 16-bit float, B a quantized integer.
    static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
                  "Element A must be fp16 or bf16");
    static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
                  "Element B must be uint8 or uint4");

    // If the math operator is a plain multiply-add, B is converted (dequantized)
    // right after the global-memory load (LDG) — i.e. before being staged in
    // shared memory — so the shared-memory/mma element type for B below becomes
    // the same float type as A. Otherwise B stays in its quantized type and the
    // conversion presumably happens after the shared-memory load (see the
    // TransformAfterLDG / TransformAfterLDS converters at the bottom).
    static constexpr bool DqAfterLDG = platform::is_same<arch::OpMultiplyAdd, Operator>::value;

    // NOTE: always false in this specialization — the enable_if above restricts
    // kMinComputeCapability < 80. Kept in this form for symmetry with the SM80+
    // specialization; consequently MmaCoreElementA is always half_t here.
    static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80;

    // Element types actually used by the mma core: bf16 inputs are computed as
    // half_t on pre-SM80 parts; B uses A's float type when dequantizing at LDG.
    using MmaCoreElementA = typename platform::conditional<arch_has_bf16_mma, ElementA, half_t>::type;
    using MmaCoreElementB = typename platform::conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type;

    // Define the MmaCore components (thread maps, shared-memory iterators, and
    // the warp-level mma policy) for a 2-stage kernel.
    using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
                                                                        WarpShape,
                                                                        InstructionShape,
                                                                        MmaCoreElementA,
                                                                        LayoutA,
                                                                        MmaCoreElementB,
                                                                        LayoutB,
                                                                        ElementAccumulator,
                                                                        layout::RowMajor,
                                                                        OperatorClass,
                                                                        2,
                                                                        Operator>;

    // Define iterators over tiles from the A operand (global memory -> registers).
    // A tiles are MxK; the third argument (1) selects the advance dimension of
    // the CUTLASS PredicatedTileIterator.
    using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
        cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>,
        ElementA,
        LayoutA,
        1,
        typename MmaCore::IteratorThreadMapA,
        kAlignmentA>;

    // Define iterators over tiles from the B operand (global memory -> registers).
    // B tiles are KxN and are loaded in the quantized ElementB type.
    using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<
        cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>,
        ElementB,
        LayoutB,
        0,
        typename MmaCore::IteratorThreadMapB,
        kAlignmentB>;

    // ThreadMap for scale iterator: one scale per column of the threadblock's
    // N dimension, loaded kAlignmentScale elements per access.
    static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, "");
    using IteratorScaleThreadMap =
        transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaCore::Shape::kN, 1>,
                                                  MmaCore::Shape::kN / kAlignmentScale,
                                                  kAlignmentScale>;

    // Define iterators over tiles from the scale operand (a 1 x N row vector).
    using IteratorScale =
        cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>,
                                                                ElementScale,
                                                                LayoutScale,
                                                                0,
                                                                IteratorScaleThreadMap,
                                                                kAlignmentScale>;

    // Scales are staged in shared memory as half_t on pre-SM80 parts
    // (arch_has_bf16_mma is always false here), matching MmaCoreElementA.
    using SmemScaleType = typename platform::conditional<arch_has_bf16_mma, ElementScale, half_t>::type;
    using SmemIteratorScale =
        cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>,
                                                                SmemScaleType,
                                                                LayoutScale,
                                                                0,
                                                                IteratorScaleThreadMap,
                                                                kAlignmentScale>;

    // Selects the numeric converters applied to B after the global load
    // (TransformAfterLDG) and after the shared-memory load (TransformAfterLDS),
    // based on the iterator/operator types chosen above.
    using Converters = SetConverters<IteratorB, typename MmaCore::MmaPolicy::Operator, Operator>;

    // Define the threadblock-scoped pipelined matrix multiply: ties together the
    // global/shared iterators for A, B, and the scales with the mma policy and
    // the two B converters.
    using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined<typename MmaCore::Shape,
                                                                      IteratorA,
                                                                      typename MmaCore::SmemIteratorA,
                                                                      IteratorB,
                                                                      typename MmaCore::SmemIteratorB,
                                                                      IteratorScale,
                                                                      SmemIteratorScale,
                                                                      ElementAccumulator,
                                                                      layout::RowMajor,
                                                                      typename MmaCore::MmaPolicy,
                                                                      typename Converters::TransformAfterLDG,
                                                                      typename Converters::TransformAfterLDS>;
};
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free