Home / Class / ScaleOperandA_ Class — pytorch Architecture

ScaleOperandA_ Class — pytorch Architecture

Architecture documentation for ScaleOperandA_, a template parameter of the MmaPipelinedFromSharedMemory class in mma_from_smem.h from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/transformers/cuda/mem_eff_attention/gemm/mma_from_smem.h lines 340–515

template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    // BEGIN smem
    /// Iterates over the intermediate accumulator tile in shared memory
    typename WarpIteratorA_,
    /// whether or not to perform elementwise multiplication of A
    //  by another matrix (A_scale) that is also kept in shared memory prior
    //  to matmul A @ B
    bool ScaleOperandA_,
    /// Max GEMM problem size in K dimension
    int MaxK,
    /// Iterates over tiles of B operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
    typename IteratorB_,
    /// Iterates over tiles of B operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorB_,
    /// Data type of accumulator matrix
    typename ElementC_,
    /// Layout of accumulator matrix
    typename LayoutC_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// Transformation applied to B operand
    typename TransformB_ = NumericArrayConverter<
        typename SmemIteratorB_::Element,
        typename IteratorB_::Element,
        IteratorB_::Fragment::kElements>,
    /// Used for partial specialization
    typename Enable = bool>
class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
                                         Shape_,
                                         MaxK,
                                         Policy_,
                                         2,
                                         typename WarpIteratorA_::Layout> {
 public:
  ///< Base class
  using Base = MmaBaseFromSharedMemory<
      Shape_,
      MaxK,
      Policy_,
      2,
      typename WarpIteratorA_::Layout>;

  using Shape =
      Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  static constexpr bool ScaleOperandA = ScaleOperandA_;

  using WarpIteratorA = WarpIteratorA_;
  ///< loads fragments of A_scale from shared memory if operand A scaling is
  ///< enabled. otherwise no-op.
  using WarpIteratorAScale = typename cutlass::platform::conditional<
      ScaleOperandA,
      WarpIteratorA,
      NoOpWarpIteratorScale<typename WarpIteratorA::TensorRef>>::type;

  using IteratorB =
      IteratorB_; ///< Iterates over tiles of B operand in global memory
  using ElementC = ElementC_; ///< Data type of accumulator matrix
  using LayoutC = LayoutC_; ///< Layout of accumulator matrix
  using Policy = Policy_; ///< Policy describing tuning details

  using SmemIteratorB = SmemIteratorB_;

  using TransformB = TransformB_;

  //
  // Dependent types
  //

  /// Fragment of operand B loaded from global memory
  using FragmentB = typename IteratorB::Fragment;

  /// Fragment of accumulator tile
  using FragmentC = typename Policy::Operator::FragmentC;

  /// Warp-level Mma
  using Operator = typename Policy::Operator;

  /// Obtain the arch tag from the warp-level operator
  using ArchTag = typename Policy::Operator::ArchTag;

  /// Complex transform on B operand
  static ComplexTransform const kTransformB = Operator::kTransformB;

  // statically assert kStages for MmaPipelined is two (Double-buffered pipeline)
  static_assert(
      (Base::kStages == 2),
      "MmaPipelined requires kStages set to value 2");

 private:
  using WarpFragmentA = typename Operator::FragmentA;

  /// fragment type of OperandA elementwise scaling matrix. (almost) empty
  /// if operand A scaling is disabled.
  using WarpFragmentAScale = typename WarpIteratorAScale::Fragment;

  using WarpFragmentB = typename Operator::FragmentB;

  /// applies scaling factor to operand A fragment if operand A scaling is
  /// enabled. otherwise no-op.
  using FragmentAScaler = FragmentElementwiseScaler<
      WarpFragmentA,
      WarpFragmentAScale,
      ScaleOperandA>;

 protected:
  // /// Iterator to write threadblock-scoped tile of A operand to shared memory
  // SmemIteratorA smem_iterator_A_;

  /// Iterator to write threadblock-scoped tile of B operand to shared memory
  SmemIteratorB smem_iterator_B_;

  /// Iterator to load a warp-scoped tile of A operand from intermediate
  /// accumulator tile
  WarpIteratorA warp_tile_iterator_A_;

  /// Iterator to load a warp-scoped tile of A_scale from intermediate
  /// accumulator tile (only used if ScaleOperandA_ is true)
  WarpIteratorAScale warp_tile_iterator_A_scale_;

 public:
  /// Constructor for MMA with operand A scaling enabled.
  ///
  /// @param a          operand A in shared memory
  /// @param a_scale    operand A_scale in shared memory; passed to
  ///                   WarpIteratorAScale, which is NoOpWarpIteratorScale
  ///                   when ScaleOperandA is false
  /// @param b_staging  staging memory for loading tiles of B
  /// @param thread_idx ID within the threadblock
  /// @param warp_idx   ID of warp
  /// @param lane_idx   ID of each thread within a warp
  ///
  /// Members are always initialized in their declaration order
  /// (smem_iterator_B_, warp_tile_iterator_A_, warp_tile_iterator_A_scale_).
  /// The initializer list below is written in that same order so it reads
  /// truthfully and does not trigger -Wreorder.
  CUTLASS_DEVICE
  MmaPipelinedFromSharedMemory(
      typename Base::TensorRefA a, // Operand A in shared memory
      typename Base::TensorRefA a_scale, // Operand A_scale in shared memory
      typename Base::TensorRefB
          b_staging, // staging memory for loading tiles of B
      int thread_idx,
      int warp_idx,
      int lane_idx)
      : Base(b_staging, thread_idx, warp_idx, lane_idx),
        smem_iterator_B_(b_staging, thread_idx),
        warp_tile_iterator_A_(a, lane_idx),
        warp_tile_iterator_A_scale_(a_scale, lane_idx) {
    // Compute warp location within threadblock tile by mapping the warp_id to
    // three coordinates:
    //   _m: the warp's position within the threadblock along the M dimension
    //   _n: the warp's position within the threadblock along the N dimension
    //   _k: the warp's position within the threadblock along the K dimension
    int const warp_count_mn = Base::WarpCount::kM * Base::WarpCount::kN;
    int warp_idx_mn = warp_idx % warp_count_mn;
    int warp_idx_k = warp_idx / warp_count_mn;
    int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
    int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;

    // Each K-slice of warps starts kWarpGemmIterations warp tiles further
    // along K; A and A_scale step in their second coordinate, B in its first.
    int const k_tile_offset = Base::kWarpGemmIterations * warp_idx_k;

    // Add per-warp offsets in units of warp-level tiles
    this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, k_tile_offset});
    this->warp_tile_iterator_A_scale_.add_tile_offset(
        {warp_idx_m, k_tile_offset});
    this->warp_tile_iterator_B_.add_tile_offset({k_tile_offset, warp_idx_n});
  }

  /// Construct from tensor references
  CUTLASS_DEVICE
  MmaPipelinedFromSharedMemory(
      typename Base::TensorRefA a, ///< Operand A in shared memory
      typename Base::TensorRefB b_staging, ///< staging memory for loading B
      int thread_idx, ///< ID within the threadblock
      int warp_idx, ///< ID of warp
      int lane_idx) ///< ID of each thread within a warp
      : Base(b_staging, thread_idx, warp_idx, lane_idx),
        warp_tile_iterator_A_(a, lane_idx),
        smem_iterator_B_(b_staging, thread_idx) {
    // Compute warp location within threadblock tile by mapping the warp_id to
    // three coordinates:
    //   _m: the warp's position within the threadblock along the M dimension
    //   _n: the warp's position within the threadblock along the N dimension
    //   _k: the warp's position within the threadblock along the K dimension

    int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free