CustomMmaMultistage Class — pytorch Architecture
Architecture documentation for the CustomMmaMultistage class, defined in custom_mma_multistage.h in the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_multistage.h lines 95–228
class CustomMmaMultistage : public CustomMmaBase<Shape_, Policy_, Stages> {
public:
// Public aliases that re-export the underscore-suffixed names (Shape_,
// IteratorA_, ...) under their public spellings, followed by constants and
// types derived from them.

/// Base class
using Base = CustomMmaBase<Shape_, Policy_, Stages>;
/// Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;
/// Iterates over tiles of A operand in global memory
using IteratorA = IteratorA_;
/// Iterates over tiles of B operand in global memory
using IteratorB = IteratorB_;
/// Data type of accumulator matrix
using ElementC = ElementC_;
/// Layout of accumulator matrix
using LayoutC = LayoutC_;
/// Policy describing tuning details
using Policy = Policy_;
/// Shared-memory iterator type for operand A; this is the type of the
/// smem_iterator_A_ member, which writes the threadblock-scoped tile of A to
/// shared memory
using SmemIteratorA = SmemIteratorA_;
/// Shared-memory iterator type for operand B; this is the type of the
/// smem_iterator_B_ member, which writes the threadblock-scoped tile of B to
/// shared memory
using SmemIteratorB = SmemIteratorB_;
/// Cache operation kind for operand A, copied from CacheOpA.
/// NOTE(review): presumably selects the caching behaviour of the cp.async
/// loads of A - confirm at the use site.
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
/// Cache operation kind for operand B, copied from CacheOpB.
/// NOTE(review): presumably selects the caching behaviour of the cp.async
/// loads of B - confirm at the use site.
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
//
// Dependent types
//
/// Fragment of accumulator tile
using FragmentC = typename Policy::Operator::FragmentC;
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Minimum architecture is Sm80 to support cp.async
using ArchTag = arch::Sm80;
/// Complex transform on A operand (taken from the warp-level Operator)
static ComplexTransform const kTransformA = Operator::kTransformA;
/// Complex transform on B operand (taken from the warp-level Operator)
static ComplexTransform const kTransformB = Operator::kTransformB;
/// Compile-time pipeline parameters, exposed for introspection.
struct Detail {
  // Reject configurations with fewer than two warp-level GEMM iterations.
  static_assert(
      Base::kWarpGemmIterations > 1,
      "The pipelined structure requires at least two warp-level "
      "GEMM operations.");

  /// Number of stages
  static const int kStages = Stages;

  // --- Operand A -----------------------------------------------------------

  /// Number of cp.async instructions to load one stage of operand A
  static const int AsyncCopyIterationsPerStageA =
      IteratorA::ThreadMap::Iterations::kCount;

  /// Number of cp.async instructions to load one group of operand A: the
  /// per-stage count divided by Base::kWarpGemmIterations, rounded up
  static const int kAccessesPerGroupA =
      (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) /
      Base::kWarpGemmIterations;

  // --- Operand B -----------------------------------------------------------

  /// Number of cp.async instructions to load one stage of operand B
  static const int AsyncCopyIterationsPerStageB =
      IteratorB::ThreadMap::Iterations::kCount;

  /// Number of cp.async instructions to load one group of operand B: the
  /// per-stage count divided by Base::kWarpGemmIterations, rounded up
  static const int kAccessesPerGroupB =
      (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) /
      Base::kWarpGemmIterations;
};
/// True when kMaxK <= Shape::kK * Stages, i.e. when Stages tiles of depth
/// Shape::kK span at least kMaxK.
/// NOTE(review): per its name this presumably means the whole operand fits in
/// the shared-memory stages at once - confirm against the definition of kMaxK.
static bool const kSmemContainsEntireMat = kMaxK <= Shape::kK * Stages;
/// Stage count used for concurrent loads (per the name): all Stages when
/// kSmemContainsEntireMat is true, otherwise Stages - 1.
static constexpr int kNumStagesConcurrentLoad =
kSmemContainsEntireMat ? Stages : Stages - 1;
private:
// Operand fragment types exported by the warp-level Operator.
// NOTE(review): by their names, the "Loaded" fragments hold operand data as
// read and the "Transformed" fragments hold it after the operator's
// transform - confirm against the Operator definition.

/// Operator's fragment type for operand A
using WarpLoadedFragmentA = typename Operator::FragmentA;
/// Operator's fragment type for operand B
using WarpLoadedFragmentB = typename Operator::FragmentB;
/// Operator's transformed-fragment type for operand A
using WarpTransformedFragmentA = typename Operator::TransformedFragmentA;
/// Operator's transformed-fragment type for operand B
using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;
private:
//
// Data members
//
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA smem_iterator_A_;
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB smem_iterator_B_;
/// Starts as false (see the constructor's initializer list).
/// NOTE(review): presumably records whether the prologue has already been
/// executed - confirm where this flag is set and read.
bool prologue_done_;
/// Set to `true` to ensure the accumulator will be zero outside the GEMM
/// footprint. Starts as false (see the constructor's initializer list).
bool zero_outside_bounds_;
public:
/// Construct from tensor references.
///
/// Forwards all arguments to the base class, binds the shared-memory write
/// iterators for both operands, clears the prologue_done_ and
/// zero_outside_bounds_ flags, and then moves the inherited warp tile
/// iterators to this warp's tile.
///
/// @param shared_storageA  Shared storage for operand A, needed for internal
///                         use by the threadblock-scoped GEMM
/// @param shared_storageB  Shared storage for operand B, needed for internal
///                         use by the threadblock-scoped GEMM
/// @param thread_idx       ID within the threadblock
/// @param warp_idx         ID of warp
/// @param lane_idx         ID of each thread within a warp
CUTLASS_DEVICE
CustomMmaMultistage(
    typename Base::SharedStorageA& shared_storageA,
    typename Base::SharedStorageB& shared_storageB,
    int thread_idx,
    int warp_idx,
    int lane_idx)
    : Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx),
      smem_iterator_A_(shared_storageA.ref(), thread_idx),
      smem_iterator_B_(shared_storageB.ref(), thread_idx),
      prologue_done_(false),
      zero_outside_bounds_(false) {
  // Decompose the linear warp index into (m, n, k) coordinates inside the
  // threadblock tile. M varies fastest, then N; K selects which group of
  // WarpCount::kM * WarpCount::kN warps this warp belongs to.
  int const warps_per_k_slice = Base::WarpCount::kM * Base::WarpCount::kN;
  int const k_slice = warp_idx / warps_per_k_slice;
  int const index_in_slice = warp_idx % warps_per_k_slice;
  int const m_pos = index_in_slice % Base::WarpCount::kM;
  int const n_pos = index_in_slice / Base::WarpCount::kM;

  // Each K slice owns kWarpGemmIterations consecutive warp-level tiles along
  // K; offsets are expressed in units of warp-level tiles.
  int const k_tile_offset = Base::kWarpGemmIterations * k_slice;
  this->warp_tile_iterator_A_.add_tile_offset({m_pos, k_tile_offset});
  this->warp_tile_iterator_B_.add_tile_offset({k_tile_offset, n_pos});
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free