Home / Class / Operand_ Class — PyTorch Architecture

Operand_ Class — pytorch Architecture

Architecture documentation for the `WarpIteratorFromSmem` class template (indexed here under its template parameter `Operand_`) in warp_iterator_from_smem.h from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/transformers/cuda/mem_eff_attention/iterators/warp_iterator_from_smem.h lines 54–238

template <
    /// Operand identity
    Operand Operand_,
    /// Data type of A elements
    typename Element_,
    typename InstructionShape_,
    bool kTranspose = false>
class WarpIteratorFromSmem {
 public:
  /// Shape of tile to load (concept: MatrixShape). Fixed at 32x32 elements,
  /// independent of the template arguments.
  using Shape = cutlass::MatrixShape<32, 32>;

  /// Operand tag
  static Operand const kOperand = Operand_;
  // Only operand A is implemented: instantiating with Operand::kB fails here.
  static_assert(
      kOperand == Operand::kA,
      "No support for OperandB at the moment");

  /// Basic check
  // NOTE(review): redundant while the kA-only assert above is in place;
  // presumably kept for when operand B support is added.
  static_assert(
      kOperand == Operand::kA || kOperand == Operand::kB,
      "WarpIteratorFromSmem may only be instantiated for A or B operands to warp-level Mma.");

  /// Element type
  using Element = Element_;
  // The check admits any element type whose size is exactly 16 bits, even
  // though the message only mentions half.
  static_assert(sizeof_bits<Element>::value == 16, "Only supported for half");

  /// Layout of source tile
  using Layout = cutlass::layout::RowMajor;

  /// Shape of one matrix product operation (concept: MatrixShape).
  /// Restricted below to 16 rows and either 8 or 16 columns.
  using InstructionShape = InstructionShape_;
  static_assert(InstructionShape::kRow == 16, "Only supports 16x8x8 / 16x8x16");
  static_assert(
      InstructionShape::kColumn == 8 || InstructionShape::kColumn == 16,
      "Only supports 16x8x8 / 16x8x16");

  /// Delta between *MMA operations (in units of *MMA operations, concept:
  /// MatrixShape)
  static int const kOpDelta = 1;

  /// Number of participating threads (one warp of 32 lanes)
  static int const kThreads = 32;

  /// TensorRef type for loading element from a tensor
  using TensorRef = TensorRef<Element, Layout>;

  /// Index type
  using Index = typename TensorRef::Index;

  /// Long Index type
  using LongIndex = typename TensorRef::LongIndex;

  /// Coordinate for an element in the tensor
  using TensorCoord = typename TensorRef::TensorCoord;

  /// Number of elements accessed per Shared Memory load: as many elements as
  /// fit in 32 bits (1 if an element is 32 bits or wider). Given the 16-bit
  /// restriction above, this evaluates to 2.
  static int const kElementsPerAccess =
      (sizeof_bits<Element>::value >= 32 ? 1
                                         : 32 / sizeof_bits<Element>::value);

  /// Number of InstructionShape tiles covering Shape along each dimension:
  /// 2 along rows; 4 (16x8x8) or 2 (16x8x16) along columns.
  using InstructionCount = MatrixShape<
      Shape::kRow / InstructionShape::kRow,
      Shape::kColumn / InstructionShape::kColumn>;

  /// Number of instruction tiles along the iteration dimension: the column
  /// count for operand A, the row count for operand B.
  static int const kIterations = (kOperand == Operand::kA)
      ? InstructionCount::kColumn
      : InstructionCount::kRow;

 public:
  //
  // Derived quantities
  //

  /// Fragment object holding a thread's part of a tile. For operand A this is
  /// Shape::kRow * InstructionShape::kColumn / kThreads elements (8 or 16).
  using Fragment = Array<
      Element,
      (kOperand == Operand::kA)
          ? (Shape::kRow* InstructionShape::kColumn / kThreads)
          : (Shape::kColumn* InstructionShape::kRow / kThreads)>;

  /// Memory access type: four `unsigned` words per access (128 bits when
  /// `unsigned` is 32 bits wide). The element-typed alternative is kept
  /// commented out below.
  // using AccessType = AlignedArray<Element, kElementsPerAccess>;
  using AccessType = Array<unsigned, 4>;

  /// Extent of one instruction along the "inner" dimension (presumably the
  /// reduction/K dimension): InstructionShape::kColumn for operand A,
  /// InstructionShape::kRow for operand B.
  static int constexpr kWarpShapeDivisibleInner =
      (kOperand == Operand::kA ? InstructionShape::kColumn
                               : InstructionShape::kRow);
  /// Number of (4 * kElementsPerAccess)-element chunks along the inner
  /// dimension of one instruction: 1 for 16x8x8, 2 for 16x8x16.
  static int constexpr kAccessesInner =
      (kWarpShapeDivisibleInner / kElementsPerAccess) / 4;
  // Number of 8-row tiles one instruction spans along its row dimension
  // (InstructionShape::kRow / 8); the asserts above force this to 2.
  static int const kTilesPerInstruction = InstructionShape::kRow / 8;
  static_assert(kTilesPerInstruction == 2, "Only supports 16x8x16 and 16x8x8");

 private:
  /// Underlying tensor reference. The constructor advances it by this lane's
  /// starting coordinate, and add_tile_offset advances it further.
  TensorRef ref_;

  /// This lane's starting coordinate relative to the original `ref`. It is
  /// accumulated with the same offsets that are applied to `ref_`.
  MatrixCoord origin_;

  /// Iterations in a tile; set to 0 by the constructor and by advance().
  int iterations_;

 public:
  /// Constructs an iterator over one full Shape::kRow x Shape::kColumn tile.
  ///
  /// Convenience overload that forwards to the extent-taking constructor,
  /// passing the tile's own extent.
  CUTLASS_HOST_DEVICE
  WarpIteratorFromSmem(TensorRef const& ref, int lane_id)
      : WarpIteratorFromSmem(ref, {Shape::kRow, Shape::kColumn}, lane_id) {}

  /// Constructs an iterator and moves `ref` to this lane's starting element.
  ///
  /// @param ref      Tensor reference positioned at the tile origin.
  /// @param extent   Not read by this constructor.
  /// @param lane_id  Lane index; presumably in [0, kThreads) — callers should
  ///                 confirm.
  ///
  /// Lanes are grouped in eights. `lane_id >> 3` selects which of the four
  /// 8x8 sub-matrices of an `ldmatrix.x4` the lane addresses, and
  /// `lane_id % 8` selects the row inside that sub-matrix. See also:
  /// https://docs.nvidia.com/cuda/archive/11.7.1/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-1688
  ///   16x8x8:  kAccessesInner = 1 (1 ldmatrix.x4)
  ///   16x8x16: kAccessesInner = 2 (2 ldmatrix.x4)
  CUTLASS_HOST_DEVICE
  WarpIteratorFromSmem(TensorRef const& ref, TensorCoord extent, int lane_id)
      : ref_(ref), iterations_(0) {
    int const matrix_id = lane_id >> 3;
    int const row_in_matrix = lane_id % 8;

    if (kOperand == Operand::kA) {
      static_assert(
          InstructionCount::kRow * kTilesPerInstruction == 4,
          "can't use ldmatrix.x4");
      // matrix_id is laid out as [inst_row][k_chunk][tile_in_inst], with
      // tile_in_inst varying fastest.
      int const tile_in_inst = matrix_id % kTilesPerInstruction;
      int const k_chunk = (matrix_id / kTilesPerInstruction) % kAccessesInner;
      int const inst_row = matrix_id / (kTilesPerInstruction * kAccessesInner);

      int delta_row = tile_in_inst * 8 + inst_row * InstructionShape::kRow;
      int delta_col = k_chunk * 4 * kElementsPerAccess;
      if (kTranspose) {
        // Only the sub-matrix offset is transposed; the per-lane row
        // (`row_in_matrix`) always lands on the row coordinate.
        int const swapped = delta_row;
        delta_row = delta_col;
        delta_col = swapped;
      }
      origin_ = MatrixCoord(row_in_matrix + delta_row, delta_col);
    } else {
      // XXX: This is not tested or used
      static_assert(InstructionCount::kColumn * kAccessesInner == 4, "");
      origin_ = MatrixCoord(0, row_in_matrix);
      CUTLASS_PRAGMA_UNROLL
      for (int inst_col = 0; inst_col < InstructionCount::kColumn; ++inst_col) {
        CUTLASS_PRAGMA_UNROLL
        for (int k_chunk = 0; k_chunk < kAccessesInner; ++k_chunk) {
          // Only the (inst_col, k_chunk) pair matching this lane's group
          // contributes an offset.
          if (k_chunk + kAccessesInner * inst_col != matrix_id) {
            continue;
          }
          int delta_row = k_chunk * 4 * kElementsPerAccess;
          int delta_col = inst_col * 8;
          if (kTranspose) {
            int const swapped = delta_row;
            delta_row = delta_col;
            delta_col = swapped;
          }
          origin_ += MatrixCoord(delta_row, delta_col);
        }
      }
    }

    ref_.add_coord_offset(origin_);
  }

  /// Moves the iterator by a whole number of Shape-sized tiles.
  ///
  /// `tile_offset` is expressed in tiles (rows, columns). It is scaled by
  /// Shape::kRow / Shape::kColumn into an element offset. When kTranspose is
  /// set, that offset has its row and column swapped. The result is then
  /// applied to both `ref_` and `origin_`.
  ///
  /// @return *this, to allow chaining.
  CUTLASS_HOST_DEVICE
  WarpIteratorFromSmem& add_tile_offset(TensorCoord const& tile_offset) {
    Index const rows = tile_offset.row() * Shape::kRow;
    Index const cols = tile_offset.column() * Shape::kColumn;
    TensorCoord const delta =
        kTranspose ? TensorCoord(cols, rows) : TensorCoord(rows, cols);

    ref_.add_coord_offset(delta);
    origin_ += delta;
    return *this;
  }

  /// Advances the iterator along the advance dimension
  CUTLASS_DEVICE
  void advance() {
    if (kOperand == Operand::kA) {
      add_tile_offset({0, 1});
    } else {
      add_tile_offset({1, 0});
    }

    iterations_ = 0;

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free