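accumToSmem (B2bGemm SIMT specialization) — pytorch Architecture
Architecture documentation for accumToSmem, a static member function of the B2bGemm specialization for cutlass::gemm::warp::MmaSimtTileIterator, defined in mma_from_smem.h in the pytorch codebase.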
Entity Profile
Source Code
aten/src/ATen/native/transformers/cuda/mem_eff_attention/gemm/mma_from_smem.h lines 1808–1942
template <
typename Operator,
typename OperatorPolicy,
typename scalar_t,
typename WarpShape_,
typename ThreadblockShape_>
struct B2bGemm<
cutlass::gemm::warp::MmaSimtTileIterator<
cutlass::MatrixShape<32, 32>,
cutlass::gemm::Operand::kC,
float,
cutlass::layout::RowMajor,
OperatorPolicy,
1,
1>,
Operator,
scalar_t,
WarpShape_,
ThreadblockShape_> {
using IteratorC = cutlass::gemm::warp::MmaSimtTileIterator<
cutlass::MatrixShape<32, 32>,
cutlass::gemm::Operand::kC,
float,
cutlass::layout::RowMajor,
OperatorPolicy,
1,
1>;
using accum_t = typename IteratorC::Element;
using WarpShape = WarpShape_;
using ThreadblockShape = ThreadblockShape_;
using FragmentC = typename IteratorC::Fragment;
using lse_scalar_t = float;
// Storage in shared-memory for Q.Kt
using AccumulatorSharedStorage =
cutlass::gemm::threadblock::AccumulatorSharedStorage<
ThreadblockShape,
scalar_t,
cutlass::layout::ColumnMajor,
cutlass::MatrixShape<0, 0> // Padding
>;
static void CUTLASS_DEVICE accumToSmem(
AccumulatorSharedStorage& shared_storage,
FragmentC const& accum,
int lane_id,
cutlass::MatrixCoord const& tile_coords) {
using Policy = typename IteratorC::Policy;
using Element = typename IteratorC::Element;
using Iterations = typename IteratorC::Iterations;
using Delta = typename IteratorC::Delta;
auto ref_ = shared_storage.accum_ref();
// ctor - MmaSimtTileIterator
// compute offset based on thread ID and lane layout
typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();
MatrixCoord lane_offset = lane_layout.inverse(lane_id) *
MatrixCoord(Policy::LaneMmaShape::kM, Policy::LaneMmaShape::kN);
ref_.add_coord_offset(lane_offset);
// Tile offset
ref_.add_coord_offset(
tile_coords *
cutlass::MatrixCoord(
{IteratorC::Shape::kRow, IteratorC::Shape::kColumn}));
// store - MmaSimtTileIterator
CUTLASS_PRAGMA_UNROLL
for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) {
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) {
CUTLASS_PRAGMA_UNROLL
for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) {
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) {
int r =
Policy::LaneMmaShape::kM * (mma_m * Policy::WarpShape::kRow) +
m;
int c = mma_n * Delta::kColumn + n;
int idx = n +
Policy::LaneMmaShape::kN *
(mma_n +
Iterations::kColumn *
(m + mma_m * Policy::LaneMmaShape::kM));
ref_.at({r, c}) = scalar_t(accum[idx]);
}
}
}
}
}
static void CUTLASS_DEVICE accumApplyLSEToSmem(
AccumulatorSharedStorage& shared_storage,
typename IteratorC::Fragment& accum,
lse_scalar_t const* lse,
int lse_extent,
int thread_id,
int warp_id,
int lane_id,
cutlass::MatrixCoord const& tile_coords) {
// Non-optimized way to apply LSE to registers
// NOTE: accum is attn.T
// TODO: Optimize for each architecture
static constexpr int WarpSize = 32;
using AccumLambdaIterator =
typename DefaultMmaAccumLambdaIterator<IteratorC, accum_t, WarpSize>::
Iterator;
auto lane_offset =
AccumLambdaIterator::get_lane_offset(lane_id, warp_id, tile_coords);
cutlass::Array<lse_scalar_t, IteratorC::Fragment::kElements> lse_prefetched;
lse_prefetched.clear();
int rowIdx = 0;
int colIdx = 0;
AccumLambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) {
++rowIdx;
colIdx = 0;
},
[&](int accum_m, int accum_n, int idx) {
if (rowIdx == 1) {
lse_prefetched[colIdx] = accum_n < lse_extent
? lse[accum_n]
: platform::numeric_limits<accum_t>::infinity();
}
accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]);
++colIdx;
},
[&](int accum_m) {});
accumToSmem(shared_storage, accum, lane_id, tile_coords);
}
};
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free