Home / Class / B2bGemm Specialization — PyTorch Architecture

B2bGemm Specialization — PyTorch Architecture

Architecture documentation for the B2bGemm partial specialization (for MmaTensorOpAccumulatorTileIterator accumulators) in mma_from_smem.h from the PyTorch codebase. Note: MatrixShape is only a concept referenced by the template parameters; the entity defined here is B2bGemm.

Entity Profile

Source Code

aten/src/ATen/native/transformers/cuda/mem_eff_attention/gemm/mma_from_smem.h lines 1471–1627

// Partial specialization of B2bGemm for accumulator tiles produced by
// cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator. It provides two
// static device helpers that move a warp's matmul accumulator fragment from
// registers into shared memory:
//   - accumToSmem:         plain store (accum_t -> scalar_t conversion only)
//   - accumApplyLSEToSmem: store after applying a log-sum-exp correction
//                          (used by the attention backward pass).
template < /// Size of the matrix to load (concept: MatrixShape)
    typename Shape_,
    /// Element type
    typename Element_,
    /// Layout of operand in memory
    typename Layout_,
    /// Shape of one matrix product operation (concept: MatrixShape)
    typename InstructionShape_,
    /// Interval between adjacent *MMA instructions (in units of MMA
    /// instructions, concept: MatrixShape)
    typename OpDelta_,
    typename Operator,
    typename scalar_t,
    typename WarpShape_,
    typename ThreadblockShape_>
struct B2bGemm<
    cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator<
        Shape_,
        Element_,
        Layout_,
        InstructionShape_,
        OpDelta_>,
    Operator,
    scalar_t,
    WarpShape_,
    ThreadblockShape_> {
  // Iterator type over the accumulator tile this specialization matches on.
  using IteratorC =
      typename cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator<
          Shape_,
          Element_,
          Layout_,
          InstructionShape_,
          OpDelta_>;
  // Register-resident fragment holding one warp's accumulator values.
  using FragmentC = typename IteratorC::Fragment;
  using InstructionShape = InstructionShape_;
  using WarpShape = WarpShape_;
  using ThreadblockShape = ThreadblockShape_;
  // Accumulator element type (typically float); scalar_t is the (possibly
  // narrower) type stored to shared memory.
  using accum_t = Element_;
  using lse_scalar_t = float;

  using SmemAccumulatorLayout = cutlass::layout::RowMajor;

  // Iterator to load accumulators (results of matmul in registers)
  using FragmentIteratorAccumulator =
      cutlass::epilogue::warp::FragmentIteratorTensorOp<
          WarpShape,
          InstructionShape,
          accum_t,
          typename Operator::Policy::Operator::FragmentC,
          cutlass::layout::RowMajor>;

  // Iterator to store to shared-memory
  using SmemIteratorD0 = typename cutlass::epilogue::warp::TileIteratorTensorOp<
      WarpShape,
      InstructionShape,
      scalar_t, // accum_t,
      SmemAccumulatorLayout>;
  // Shared-memory buffer sized/laid out to match what SmemIteratorD0 writes.
  using AccumulatorSharedStorage =
      cutlass::gemm::threadblock::AccumulatorSharedStorage<
          ThreadblockShape,
          typename SmemIteratorD0::Element,
          typename SmemIteratorD0::TensorLayout,
          typename SmemIteratorD0::Padding>;
  // We need to provide an operation for the epilogue. Let's create an
  // operation that does nothing (ScaleType::Nothing), just converts
  // from accum_t (float) -> scalar_t (can be half)
  using OutputOpNoOp = cutlass::epilogue::thread::LinearCombination<
      typename SmemIteratorD0::Element, // ElementOutput
      FragmentIteratorAccumulator::Fragment::kElements,
      accum_t, // ElementAccumulator
      typename SmemIteratorD0::Element, // ElementCompute
      cutlass::epilogue::thread::ScaleType::Nothing>;
  // Epilogue that streams the accumulator fragment through OutputOpNoOp
  // and writes the result into shared memory via SmemIteratorD0.
  using Epilogue = cutlass::epilogue::threadblock::EpilogueSmemAccumulator<
      SmemIteratorD0,
      FragmentIteratorAccumulator,
      SmemIteratorD0, // ScaleBiasIterator - not used
      OutputOpNoOp>;

  // Epilogue 2: with LSE (for backwards pass)
  static int const kElementsPerAccess = 2; // TODO: Why 2?
  // Predicated vector iterator over the per-row log-sum-exp values
  // (one lse_scalar_t per output row).
  using IteratorAccumulatorLSE =
      cutlass::transform::threadblock::VectorIterator<
          cutlass::transform::threadblock::PredicatedVectorAccessIterator<
              // Shape
              cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kN>,
              // WarpShape
              cutlass::MatrixShape<WarpShape::kM, WarpShape::kN>,
              lse_scalar_t,
              cutlass::layout::RowMajor,
              kElementsPerAccess>>;
  // Output op computing exp(accum - lse) per element; the last template
  // argument is the per-access element count (128 bits / element width).
  using EpilogueOpApplyLSE = cutlass::epilogue::thread::ApplyLogSumExp<
      scalar_t, // ElementOutput_
      lse_scalar_t, // ElementLSE_
      accum_t, // ElementAccumulator_
      accum_t, // ElementCompute_
      128 / cutlass::sizeof_bits<scalar_t>::value
      // FragmentIteratorAccumulator::Fragment::kElements
      // InstructionShape::kM * InstructionShape::kN / 32
      >;
  using EpilogueWithLSE =
      cutlass::epilogue::threadblock::EpilogueSmemAccumulator<
          SmemIteratorD0,
          FragmentIteratorAccumulator,
          IteratorAccumulatorLSE,
          EpilogueOpApplyLSE>;

  /// Writes the warp's accumulator fragment to shared memory, converting
  /// from accum_t to scalar_t but otherwise leaving values unchanged.
  ///
  /// @param shared_storage destination shared-memory buffer
  /// @param accum          register fragment produced by the matmul
  /// @param lane_id        lane index of the calling thread within its warp
  /// @param tile_coords    warp tile coordinates within the threadblock
  static void CUTLASS_DEVICE accumToSmem(
      AccumulatorSharedStorage& shared_storage,
      FragmentC const& accum,
      int lane_id,
      cutlass::MatrixCoord const& tile_coords) {
    SmemIteratorD0 smem_iterator_attn(shared_storage.accum_ref(), lane_id);
    // Convert warp tile coordinates into tile-iteration units understood by
    // the smem iterator (element-wise multiply of the two coordinates).
    smem_iterator_attn.add_tile_offset(
        tile_coords *
        cutlass::MatrixCoord{
            SmemIteratorD0::TileIterations::kRow,
            SmemIteratorD0::TileIterations::kColumn});
    Epilogue epilogue;
    // No-op output functor: only the accum_t -> scalar_t conversion applies.
    epilogue(OutputOpNoOp({}), smem_iterator_attn, accum);
  }

  /// Applies the log-sum-exp correction (exp(accum - lse) via
  /// EpilogueOpApplyLSE) to the accumulator fragment and writes the result
  /// to shared memory.
  ///
  /// @param shared_storage destination shared-memory buffer
  /// @param accum          register fragment produced by the matmul
  /// @param lse            pointer to the per-row log-sum-exp values
  /// @param lse_extents    number of valid LSE entries (rows)
  /// @param thread_id      thread index within the threadblock
  /// @param warp_id        warp index within the threadblock
  /// @param lane_id        lane index within the warp
  /// @param tile_coords    warp tile coordinates within the threadblock
  static void CUTLASS_DEVICE accumApplyLSEToSmem(
      AccumulatorSharedStorage& shared_storage,
      FragmentC& accum,
      lse_scalar_t const* lse,
      int32_t lse_extents,
      int thread_id,
      int warp_id,
      int lane_id,
      cutlass::MatrixCoord const& tile_coords) {
    // The LSE extent passed to the iterator is rounded up to a multiple of
    // 32; out-of-range accesses are masked by the predicated iterator.
    constexpr int32_t kAlignLSE = 32;
    IteratorAccumulatorLSE iterator_lse(
        lse,
        {(int32_t)0, (int32_t)ceil_div(lse_extents, kAlignLSE) * kAlignLSE},
        thread_id,
        warp_id,
        cutlass::MatrixCoord{0, 0} // offset
    );

    SmemIteratorD0 smem_iterator_attn(shared_storage.accum_ref(), lane_id);
    // Same tile-coordinate scaling as in accumToSmem above.
    smem_iterator_attn.add_tile_offset(
        tile_coords *
        cutlass::MatrixCoord{
            SmemIteratorD0::TileIterations::kRow,
            SmemIteratorD0::TileIterations::kColumn});
    EpilogueWithLSE epilogue;
    EpilogueOpApplyLSE minus_lse_exp({});
    // The same LSE iterator is passed for both the scale and bias slots;
    // the scale slot is unused by EpilogueOpApplyLSE.
    epilogue(
        minus_lse_exp,
        smem_iterator_attn,
        accum,
        // scale - unused
        iterator_lse,
        // bias
        iterator_lse);
  }
};

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free