Home / Class / MmaMultistageFromSharedMemory Class — pytorch Architecture

MmaMultistageFromSharedMemory Class — pytorch Architecture

Architecture documentation for the MmaMultistageFromSharedMemory class in mma_from_smem.h from the pytorch codebase.

Entity Profile

Source Code

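aten/src/ATen/native/transformers/cuda/mem_eff_attention/gemm/mma_from_smem.h, lines 727–873 (excerpt: from the class declaration through its first constructor; the template parameter list and the remaining members are not shown)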

class MmaMultistageFromSharedMemory : public MmaBaseFromSharedMemory<
                                          Shape1_,
                                          kMaxK_,
                                          Policy1_,
                                          Stages_,
                                          typename WarpIteratorA1_::Layout> {
 public:
  ///< Base class
  using Base = MmaBaseFromSharedMemory<
      Shape1_,
      kMaxK_,
      Policy1_,
      Stages_,
      typename WarpIteratorA1_::Layout>;

  ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using Shape1 = Shape1_;
  ///< Iterates over tiles of B operand in global memory
  using IteratorB1 = IteratorB1_;
  using IteratorB = IteratorB1;
  ///< Policy describing tuning details
  using Policy1 = Policy1_;

  using SmemIteratorB1 = SmemIteratorB1_;
  using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate
                                          ///< accumulator tile in shared memory
  static constexpr bool ScaleOperandA = ScaleOperandA_;

  ///< warp level iterator over A_scale matrix tile kept in shared memory.
  ///< if elementwise A scaling is disabled then everything this does is no-op.
  using WarpIteratorAScale = typename cutlass::platform::conditional<
      ScaleOperandA,
      WarpIteratorA1,
      NoOpWarpIteratorScale<typename WarpIteratorA1::TensorRef>>::type;
  ///< Data type of accumulator matrix
  using ElementC = ElementC_;
  ///< Layout of accumulator matrix
  using LayoutC = LayoutC_;

  static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1;
  static constexpr bool kSmemContainsEntireB = Base::kSmemContainsEntireB;

  //
  // Dependent types
  //

  /// Fragment of accumulator tile
  using FragmentC1 = typename Policy1::Operator::FragmentC;
  using FragmentC = FragmentC1;

  /// Warp-level Mma
  using Operator1 = typename Policy1::Operator;

  /// Minimum architecture is Sm80 to support cp.async
  using ArchTag = arch::Sm80;

  /// Complex transform on B operand
  static ComplexTransform const kTransformB1 = Operator1::kTransformB;

  /// Compile-time constants describing how operand B is loaded; exposed so
  /// that callers can introspect the pipeline configuration.
  struct Detail {
    /// How many cp.async instructions it takes to bring one full stage of
    /// operand B into shared memory.
    static const int TBLoadIterationsB1 =
        IteratorB1::ThreadMap::Iterations::kCount;

    /// How many cp.async instructions are issued per group when loading
    /// operand B: the per-stage count divided (rounded up) across the
    /// warp-level GEMM iterations.
    static const int kAccessesPerGroupB1 =
        (TBLoadIterationsB1 + Base::kWarpGemmIterations1 - 1) /
        Base::kWarpGemmIterations1;

    // The multistage pipeline cannot be formed with a single warp-level
    // GEMM iteration.
    static_assert(
        Base::kWarpGemmIterations1 > 1,
        "The pipelined structure requires at least two warp-level "
        "GEMM operations.");
  };

  static constexpr int kNumStagesConcurrentLoad =
      kSmemContainsEntireB ? Base::kStages : Base::kStages - 1;

 private:
  using WarpLoadedFragmentA1 = typename Operator1::FragmentA;
  /// fragment of OperandA scale matrix. if operand A scaling is disabled this
  /// is (almost) empty.
  using WarpLoadedFragmentA1Scale = typename WarpIteratorAScale::Fragment;
  using WarpLoadedFragmentB1 = typename Operator1::FragmentB;
  using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA;
  using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB;

  /// applies elementwise scaling to fragment of A. if operand A scaling is
  /// disabled this is a no-op.
  using FragmentAScaler = FragmentElementwiseScaler<
      WarpLoadedFragmentA1,
      WarpLoadedFragmentA1Scale,
      ScaleOperandA>;

 private:
  //
  // Data members
  //

  /// Iterator to load a warp-scoped tile of A1 operand from intermediate
  /// accumulator tile
  WarpIteratorA1 warp_tile_iterator_A1_;

  /// Iterator to load a warp-scoped tile of A1_scale operand from shared memory
  /// if operand A scaling is disabled everything this does is a no-op.
  WarpIteratorAScale warp_tile_iterator_A1_scale_;

  /// Iterator to write threadblock-scoped tile of B operand to shared memory
  SmemIteratorB1 smem_iterator_B1_;

  bool prologue_done_;

 public:
  /// Constructs the MMA with elementwise scaling of operand A enabled.
  ///
  /// a          - tensor ref for the A1 operand (the intermediate accumulator
  ///              tile), read by the A1 warp iterator
  /// a_scale    - tensor ref for the A1 scale tile; when ScaleOperandA is
  ///              false it feeds a no-op scale iterator
  /// b_tile     - tensor ref for operand B, forwarded to the base class and
  ///              to the B shared-memory iterator
  /// thread_idx - thread index, forwarded to the base class and the B
  ///              shared-memory iterator
  /// warp_idx   - warp index, decomposed below into (m, n, k) warp coordinates
  /// lane_idx   - lane index, forwarded to the base class and to the A1 /
  ///              A1-scale warp iterators
  CUTLASS_DEVICE
  MmaMultistageFromSharedMemory(
      typename Base::TensorRefA a,
      typename Base::TensorRefA a_scale,
      typename Base::TensorRefB b_tile,
      int thread_idx,
      int warp_idx,
      int lane_idx)
      : Base(b_tile, thread_idx, warp_idx, lane_idx),
        warp_tile_iterator_A1_(a, lane_idx),
        warp_tile_iterator_A1_scale_(a_scale, lane_idx),
        smem_iterator_B1_(b_tile, thread_idx),
        prologue_done_(false) {
    // Split the linear warp index into its position inside the threadblock
    // tile. Warps are numbered with m varying fastest, then n, then k.
    int const warps_per_k_slice = Base::WarpCount1::kM * Base::WarpCount1::kN;
    int const warp_mn = warp_idx % warps_per_k_slice;
    int const warp_k = warp_idx / warps_per_k_slice;
    int const warp_m = warp_mn % Base::WarpCount1::kM;
    int const warp_n = warp_mn / Base::WarpCount1::kM;

    // Every k-slice of warps begins kWarpGemmIterations1 warp-level tiles
    // further along K than the previous one.
    int const warp_k_tile = Base::kWarpGemmIterations1 * warp_k;

    // Shift each warp-level iterator by whole tiles: A1 and its scale are
    // addressed as (m, k), B as (k, n).
    warp_tile_iterator_A1_.add_tile_offset({warp_m, warp_k_tile});
    warp_tile_iterator_A1_scale_.add_tile_offset({warp_m, warp_k_tile});
    this->warp_tile_iterator_B_.add_tile_offset({warp_k_tile, warp_n});
  }

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free