WarpIteratorFromSmem Class — PyTorch Architecture
Architecture documentation for the WarpIteratorFromSmem class template in warp_iterator_from_smem.h from the PyTorch codebase. The class is a warp-level iterator, parameterized by Operand_, that loads a 32x32 operand tile from shared memory into per-thread register fragments for tensor-core MMA instructions.
Entity Profile
Source Code
aten/src/ATen/native/transformers/cuda/mem_eff_attention/iterators/warp_iterator_from_smem.h lines 54–238
template <
    /// Operand identity
    Operand Operand_,
    /// Data type of A elements
    typename Element_,
    typename InstructionShape_,
    bool kTranspose = false>
class WarpIteratorFromSmem {
 public:
  /// Shape of tile to load (concept: MatrixShape)
  using Shape = cutlass::MatrixShape<32, 32>;

  /// Operand tag
  static Operand const kOperand = Operand_;
  static_assert(
      kOperand == Operand::kA,
      "No support for OperandB at the moment");

  /// Basic check
  static_assert(
      kOperand == Operand::kA || kOperand == Operand::kB,
      "WarpIteratorFromSmem may only be instantiated for A or B operands to warp-level Mma.");

  /// Element type
  using Element = Element_;
  static_assert(sizeof_bits<Element>::value == 16, "Only supported for half");

  /// Layout of source tile
  using Layout = cutlass::layout::RowMajor;

  /// Shape of one matrix product operation (concept: MatrixShape)
  using InstructionShape = InstructionShape_;
  static_assert(InstructionShape::kRow == 16, "Only supports 16x8x8 / 16x8x16");
  static_assert(
      InstructionShape::kColumn == 8 || InstructionShape::kColumn == 16,
      "Only supports 16x8x8 / 16x8x16");

  /// Delta between *MMA operations (in units of *MMA operations, concept:
  /// MatrixShape)
  static int const kOpDelta = 1;

  /// Number of participating threads
  static int const kThreads = 32;

  /// TensorRef type for loading element from a tensor
  using TensorRef = TensorRef<Element, Layout>;

  /// Index type
  using Index = typename TensorRef::Index;

  /// Long Index type
  using LongIndex = typename TensorRef::LongIndex;

  /// Coordinate for an element in the tensor
  using TensorCoord = typename TensorRef::TensorCoord;

  /// Number of elements accessed per Shared Memory load
  static int const kElementsPerAccess =
      (sizeof_bits<Element>::value >= 32 ? 1
                                         : 32 / sizeof_bits<Element>::value);
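  // Editorial note: with Element = half (16 bits, per the static_assert
  // above), this evaluates to 32 / 16 = 2 elements per 32-bit access.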
  using InstructionCount = MatrixShape<
      Shape::kRow / InstructionShape::kRow,
      Shape::kColumn / InstructionShape::kColumn>;
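  // Editorial note: for the 32x32 tile this gives 32 / 16 = 2 instruction
  // rows, and either 32 / 8 = 4 (mma.16x8x8) or 32 / 16 = 2 (mma.16x8x16)
  // instruction columns.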
  static int const kIterations = (kOperand == Operand::kA)
      ? InstructionCount::kColumn
      : InstructionCount::kRow;
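  // Editorial note: the A operand therefore iterates along the K (column)
  // dimension of the tile: kIterations = 4 for mma.16x8x8 and 2 for
  // mma.16x8x16.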
 public:
  //
  // Derived quantities
  //

  /// Fragment object holding a thread's part of a tile
  using Fragment = Array<
      Element,
      (kOperand == Operand::kA)
          ? (Shape::kRow * InstructionShape::kColumn / kThreads)
          : (Shape::kColumn * InstructionShape::kRow / kThreads)>;
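  // Editorial note: for operand A each of the 32 lanes holds
  // 32 * 8 / 32 = 8 half elements with mma.16x8x8, or 32 * 16 / 32 = 16
  // with mma.16x8x16.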
  /// Memory access type
  // using AccessType = AlignedArray<Element, kElementsPerAccess>;
  using AccessType = Array<unsigned, 4>;

  static int constexpr kWarpShapeDivisibleInner =
      (kOperand == Operand::kA ? InstructionShape::kColumn
                               : InstructionShape::kRow);
  static int constexpr kAccessesInner =
      (kWarpShapeDivisibleInner / kElementsPerAccess) / 4;
  // Number of 8x8 tiles to load per MMA instruction along the M dimension
  static int const kTilesPerInstruction = InstructionShape::kRow / 8;
  static_assert(kTilesPerInstruction == 2, "Only supports 16x8x16 and 16x8x8");
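  // Editorial note: with kElementsPerAccess = 2, kAccessesInner works out to
  // (8 / 2) / 4 = 1 for mma.16x8x8 and (16 / 2) / 4 = 2 for mma.16x8x16,
  // matching the "1 ldmatrix.x4" / "2 ldmatrix.x4" comments in the
  // constructor below.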
 private:
  /// Underlying tensor reference
  TensorRef ref_;

  /// Origin
  MatrixCoord origin_;

  /// Iterations in a tile
  int iterations_;

 public:
  /// Constructor from TensorRef
  CUTLASS_HOST_DEVICE
  WarpIteratorFromSmem(TensorRef const& ref, int lane_id)
      : WarpIteratorFromSmem(ref, {Shape::kRow, Shape::kColumn}, lane_id) {}

  CUTLASS_HOST_DEVICE
  WarpIteratorFromSmem(TensorRef const& ref, TensorCoord extent, int lane_id)
      : ref_(ref), iterations_(0) {
    // See also:
    // https://docs.nvidia.com/cuda/archive/11.7.1/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-1688
    // 16x8x8: kAccessesInner = 1 (1 ldmatrix.x4)
    // 16x8x16: kAccessesInner = 2 (2 ldmatrix.x4)
    int ldsm_vec_num = (lane_id >> 3);
    if (kOperand == Operand::kA) {
      origin_ = MatrixCoord(lane_id % 8, 0);
      static_assert(
          InstructionCount::kRow * kTilesPerInstruction == 4,
          "can't use ldmatrix.x4");
      int access_m_idx = ldsm_vec_num % kTilesPerInstruction;
      int inner_idx = (ldsm_vec_num / kTilesPerInstruction) % kAccessesInner;
      int inst_m_idx = ldsm_vec_num / (kTilesPerInstruction * kAccessesInner);
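      // Editorial note: e.g. with mma.16x8x8 (kAccessesInner = 1), lane 27
      // has ldsm_vec_num = 3, giving access_m_idx = 1, inner_idx = 0,
      // inst_m_idx = 1; with lane_id % 8 = 3 its origin_ lands at row
      // 3 + 8 + 16 = 27, so each group of 8 lanes supplies the row
      // addresses of one of the four 8x8 ldmatrix tiles.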
      MatrixCoord offset(
          access_m_idx * 8 + inst_m_idx * InstructionShape::kRow,
          inner_idx * 4 * kElementsPerAccess);
      if (kTranspose) {
        offset = MatrixCoord(offset.column(), offset.row());
      }
      origin_ += offset;
    } else {
      // XXX: This is not tested or used
      origin_ = MatrixCoord(0, lane_id % 8);
      static_assert(InstructionCount::kColumn * kAccessesInner == 4, "");
      CUTLASS_PRAGMA_UNROLL
      for (int inst_n_idx = 0; inst_n_idx < InstructionCount::kColumn;
           ++inst_n_idx) {
        CUTLASS_PRAGMA_UNROLL
        for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) {
          int access_idx = inner_idx + kAccessesInner * inst_n_idx;
          MatrixCoord offset(
              inner_idx * 4 * kElementsPerAccess, inst_n_idx * 8);
          if (access_idx == ldsm_vec_num) {
            if (kTranspose) {
              offset = MatrixCoord(offset.column(), offset.row());
            }
            origin_ += offset;
          }
        }
      }
    }

    ref_.add_coord_offset(origin_);
  }

  /// Advances an iterator along logical dimensions of matrix in units of whole
  /// tiles
  CUTLASS_HOST_DEVICE
  WarpIteratorFromSmem& add_tile_offset(TensorCoord const& tile_offset) {
    TensorCoord coord_offset(
        tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn);
    if (kTranspose) {
      coord_offset = TensorCoord{coord_offset.column(), coord_offset.row()};
    }
    origin_ += coord_offset;
    ref_.add_coord_offset(coord_offset);
    return *this;
  }
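  // Editorial note: tile_offset is measured in whole 32x32 tiles; the
  // row/column swap under kTranspose applies the logical tile step to the
  // transposed physical layout, mirroring the swap in the constructor's
  // per-lane offset.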
  /// Advances the iterator along the advance dimension
  CUTLASS_DEVICE
  void advance() {
    if (kOperand == Operand::kA) {
      add_tile_offset({0, 1});
    } else {
      add_tile_offset({1, 0});
    }
    iterations_ = 0;
  }

  // NOTE (editorial): the listing is truncated here. The remainder of the
  // class (operator++ and the ldmatrix-based load(Fragment&)) continues
  // through line 238 of the source file.
};
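For orientation, here is a minimal sketch of how this iterator could be instantiated. The alias name, the half_t element type, the namespace qualification, and the use of cutlass::MatrixShape<16, 8> as the instruction footprint are editorial assumptions inferred from the static_asserts above, not taken from this file.

// Hypothetical instantiation (editorial sketch, not from the source file):
// an A-operand iterator over a row-major 32x32 half-precision tile in
// shared memory, with the 16x8 (M x K) footprint of one mma.16x8x8.
using WarpIteratorA = cutlass::gemm::warp::WarpIteratorFromSmem<
    cutlass::gemm::Operand::kA,  // only kA is supported (see static_assert)
    cutlass::half_t,             // 16-bit element, as required
    cutlass::MatrixShape<16, 8>, // InstructionShape: kRow = 16, kColumn = 8
    false>;                      // kTranspose

// Each warp would then drive it roughly as (assumed usage pattern):
//   WarpIteratorA iter_A(ref_A, lane_id); // ref_A: TensorRef into smem
//   WarpIteratorA::Fragment frag;
//   iter_A.load(frag);                    // ldmatrix into registers
//   ++iter_A;                             // next K step; wraps via advance()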