Home / Class / accumToSmem Class — pytorch Architecture

accumToSmem Class — pytorch Architecture

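Architecture documentation for `accumToSmem`, a static device method of a `B2bGemm` partial specialization (listed below together with its sibling method `accumApplyLSEToSmem`), in mma_from_smem.h from the PyTorch codebase.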

Entity Profile

Source Code

aten/src/ATen/native/transformers/cuda/mem_eff_attention/gemm/mma_from_smem.h lines 1808–1942

template <
    typename Operator,
    typename OperatorPolicy,
    typename scalar_t,
    typename WarpShape_,
    typename ThreadblockShape_>
struct B2bGemm<
    cutlass::gemm::warp::MmaSimtTileIterator<
        cutlass::MatrixShape<32, 32>,
        cutlass::gemm::Operand::kC,
        float,
        cutlass::layout::RowMajor,
        OperatorPolicy,
        1,
        1>,
    Operator,
    scalar_t,
    WarpShape_,
    ThreadblockShape_> {
  using IteratorC = cutlass::gemm::warp::MmaSimtTileIterator<
      cutlass::MatrixShape<32, 32>,
      cutlass::gemm::Operand::kC,
      float,
      cutlass::layout::RowMajor,
      OperatorPolicy,
      1,
      1>;
  using accum_t = typename IteratorC::Element;
  using WarpShape = WarpShape_;
  using ThreadblockShape = ThreadblockShape_;
  using FragmentC = typename IteratorC::Fragment;
  using lse_scalar_t = float;

  // Storage in shared-memory for Q.Kt
  using AccumulatorSharedStorage =
      cutlass::gemm::threadblock::AccumulatorSharedStorage<
          ThreadblockShape,
          scalar_t,
          cutlass::layout::ColumnMajor,
          cutlass::MatrixShape<0, 0> // Padding
          >;

  static void CUTLASS_DEVICE accumToSmem(
      AccumulatorSharedStorage& shared_storage,
      FragmentC const& accum,
      int lane_id,
      cutlass::MatrixCoord const& tile_coords) {
    using Policy = typename IteratorC::Policy;
    using Element = typename IteratorC::Element;
    using Iterations = typename IteratorC::Iterations;
    using Delta = typename IteratorC::Delta;

    auto ref_ = shared_storage.accum_ref();
    // ctor - MmaSimtTileIterator
    // compute offset based on thread ID and lane layout
    typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();

    MatrixCoord lane_offset = lane_layout.inverse(lane_id) *
        MatrixCoord(Policy::LaneMmaShape::kM, Policy::LaneMmaShape::kN);

    ref_.add_coord_offset(lane_offset);

    // Tile offset
    ref_.add_coord_offset(
        tile_coords *
        cutlass::MatrixCoord(
            {IteratorC::Shape::kRow, IteratorC::Shape::kColumn}));

    // store - MmaSimtTileIterator
    CUTLASS_PRAGMA_UNROLL
    for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) {
      CUTLASS_PRAGMA_UNROLL
      for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) {
        CUTLASS_PRAGMA_UNROLL
        for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) {
          CUTLASS_PRAGMA_UNROLL
          for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) {
            int r =
                Policy::LaneMmaShape::kM * (mma_m * Policy::WarpShape::kRow) +
                m;
            int c = mma_n * Delta::kColumn + n;
            int idx = n +
                Policy::LaneMmaShape::kN *
                    (mma_n +
                     Iterations::kColumn *
                         (m + mma_m * Policy::LaneMmaShape::kM));
            ref_.at({r, c}) = scalar_t(accum[idx]);
          }
        }
      }
    }
  }

  static void CUTLASS_DEVICE accumApplyLSEToSmem(
      AccumulatorSharedStorage& shared_storage,
      typename IteratorC::Fragment& accum,
      lse_scalar_t const* lse,
      int lse_extent,
      int thread_id,
      int warp_id,
      int lane_id,
      cutlass::MatrixCoord const& tile_coords) {
    // Non-optimized way to apply LSE to registers
    // NOTE: accum is attn.T
    // TODO: Optimize for each architecture
    static constexpr int WarpSize = 32;
    using AccumLambdaIterator =
        typename DefaultMmaAccumLambdaIterator<IteratorC, accum_t, WarpSize>::
            Iterator;
    auto lane_offset =
        AccumLambdaIterator::get_lane_offset(lane_id, warp_id, tile_coords);

    cutlass::Array<lse_scalar_t, IteratorC::Fragment::kElements> lse_prefetched;
    lse_prefetched.clear();
    int rowIdx = 0;
    int colIdx = 0;
    AccumLambdaIterator::iterateRows(
        lane_offset,
        [&](int accum_m) {
          ++rowIdx;
          colIdx = 0;
        },
        [&](int accum_m, int accum_n, int idx) {
          if (rowIdx == 1) {
            lse_prefetched[colIdx] = accum_n < lse_extent
                ? lse[accum_n]
                : platform::numeric_limits<accum_t>::infinity();
          }
          accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]);
          ++colIdx;
        },
        [&](int accum_m) {});
    accumToSmem(shared_storage, accum, lane_id, tile_coords);
  }
};

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free