kTransformA (static member of CustomMmaPipelined) — PyTorch Architecture
Architecture documentation for `kTransformA`, the static `ComplexTransform` member of the `CustomMmaPipelined` class template in custom_mma_pipelined.h from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_pipelined.h lines 60–206 (the excerpt below is truncated after the constructors)
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Iterates over tiles of A operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorA_,
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorB_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorB_,
/// Data type of accumulator matrix
typename ElementC_,
/// Data type of accumulator matrix
typename LayoutC_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Transformation applied to A operand
typename TransformA_ = NumericArrayConverter<
typename SmemIteratorA_::Element,
typename IteratorA_::Element,
IteratorA_::Fragment::kElements>,
///
/// Transformation applied to B operand
typename TransformB_ = NumericArrayConverter<
typename SmemIteratorB_::Element,
typename IteratorB_::Element,
IteratorB_::Fragment::kElements>,
/// Used for partial specialization
typename Enable = bool>
class CustomMmaPipelined : public CustomMmaBase<Shape_, Policy_, 2> {
public:
///< Base class
using Base = CustomMmaBase<Shape_, Policy_, 2>;
using Shape =
Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
using IteratorA =
IteratorA_; ///< Iterates over tiles of A operand in global memory
using IteratorB =
IteratorB_; ///< Iterates over tiles of B operand in global memory
using ElementC = ElementC_; ///< Data type of accumulator matrix
using LayoutC = LayoutC_; ///< Layout of accumulator matrix
using Policy = Policy_; ///< Policy describing tuning details
using SmemIteratorA = SmemIteratorA_;
using SmemIteratorB = SmemIteratorB_;
using TransformA = TransformA_;
using TransformB = TransformB_;
//
// Dependent types
//
/// Fragment of operand A loaded from global memory
using FragmentA = typename IteratorA::Fragment;
/// Fragment of operand B loaded from global memory
using FragmentB = typename IteratorB::Fragment;
/// Fragment of accumulator tile
using FragmentC = typename Policy::Operator::FragmentC;
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Obtain the arch tag from the warp-level operator
using ArchTag = typename Policy::Operator::ArchTag;
/// Complex transform on A operand
static ComplexTransform const kTransformA = Operator::kTransformA;
/// Complex transform on B operand
static ComplexTransform const kTransformB = Operator::kTransformB;
// statically assert kStages for MmaPipelined is two (Double-buffered pipeline)
static_assert(
(Base::kStages == 2),
"MmaPipelined requires kStages set to value 2");
static bool const kSmemContainsEntireMat = false;
private:
using WarpFragmentA = typename Operator::FragmentA;
using WarpFragmentB = typename Operator::FragmentB;
protected:
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA smem_iterator_A_;
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB smem_iterator_B_;
public:
/// Builds the double-buffered MMA from separate shared-memory storage objects
/// for the A and B operands.
///
/// @param shared_storageA  storage for operand A; forwarded to Base and its
///                         ref() seeds smem_iterator_A_
/// @param shared_storageB  storage for operand B; forwarded to Base and its
///                         ref() seeds smem_iterator_B_
/// @param thread_idx       index of this thread within the threadblock
/// @param warp_idx         index of this thread's warp
/// @param lane_idx         index of this thread within its warp (only
///                         forwarded to Base here)
CUTLASS_DEVICE
CustomMmaPipelined(
    typename Base::SharedStorageA& shared_storageA,
    typename Base::SharedStorageB& shared_storageB,
    int thread_idx,
    int warp_idx,
    int lane_idx)
    : Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx),
      smem_iterator_A_(shared_storageA.ref(), thread_idx),
      smem_iterator_B_(shared_storageB.ref(), thread_idx) {
  // Decompose the linear warp index into (m, n, k) coordinates. Within one
  // M x N plane of warps, M varies fastest; successive planes step along K.
  int const warps_per_mn_plane = Base::WarpCount::kM * Base::WarpCount::kN;
  int const warp_k = warp_idx / warps_per_mn_plane;
  int const warp_in_plane = warp_idx % warps_per_mn_plane;
  int const warp_m = warp_in_plane % Base::WarpCount::kM;
  int const warp_n = warp_in_plane / Base::WarpCount::kM;

  // A warp at K-position warp_k starts kWarpGemmIterations warp tiles further
  // along K for every plane before it.
  int const k_tile_offset = Base::kWarpGemmIterations * warp_k;

  // Move the warp-level iterators (offsets are in units of warp tiles):
  // A is indexed (row=M, col=K), B is indexed (row=K, col=N).
  this->warp_tile_iterator_A_.add_tile_offset({warp_m, k_tile_offset});
  this->warp_tile_iterator_B_.add_tile_offset({k_tile_offset, warp_n});
}
/// Convenience constructor taking the combined shared storage; it delegates to
/// the per-operand constructor with st.operand_A and st.operand_B.
CUTLASS_DEVICE
CustomMmaPipelined(
    typename Base::SharedStorage& st, ///< combined storage (operand_A, operand_B)
    int thread_idx, ///< index of this thread within the threadblock
    int warp_idx, ///< index of this thread's warp
    int lane_idx ///< index of this thread within its warp
    )
    : CustomMmaPipelined(st.operand_A, st.operand_B, thread_idx, warp_idx, lane_idx) {}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free