B2bGemm Specialization (MmaTensorOpAccumulatorTileIterator) — pytorch Architecture
Architecture documentation for the B2bGemm partial specialization (selected for cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator) in mma_from_smem.h from the pytorch codebase. Note: "MatrixShape" appears only as a template-parameter concept annotation; the entity defined here is B2bGemm.
Entity Profile
Source Code
aten/src/ATen/native/transformers/cuda/mem_eff_attention/gemm/mma_from_smem.h lines 1471–1627
// Partial specialization of B2bGemm selected when the first GEMM's output is
// held in a warp-level MmaTensorOpAccumulatorTileIterator (Tensor Core MMA).
// Purpose (from the visible code): take the accumulator fragment of the first
// matmul (in registers, element type accum_t) and write it into shared memory
// (converted to scalar_t) so it can serve as an operand for a second GEMM —
// optionally applying a log-sum-exp correction on the way (backward pass).
// NOTE(review): descriptions of CUTLASS types below are inferred from their
// names and visible template arguments only — confirm against CUTLASS headers.
template < /// Size of the matrix to load (concept: MatrixShape)
typename Shape_,
/// Element type
typename Element_,
/// Layout of operand in memory
typename Layout_,
/// Shape of one matrix product operation (concept: MatrixShape)
typename InstructionShape_,
/// Interval between adjacent *MMA instructions (in units of MMA
/// instructions, concept: MatrixShape)
typename OpDelta_,
typename Operator,
typename scalar_t,
typename WarpShape_,
typename ThreadblockShape_>
struct B2bGemm<
cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator<
Shape_,
Element_,
Layout_,
InstructionShape_,
OpDelta_>,
Operator,
scalar_t,
WarpShape_,
ThreadblockShape_> {
// Iterator type over the first GEMM's accumulator tile; FragmentC is the
// per-thread register fragment that accumToSmem/accumApplyLSEToSmem consume.
using IteratorC =
typename cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator<
Shape_,
Element_,
Layout_,
InstructionShape_,
OpDelta_>;
using FragmentC = typename IteratorC::Fragment;
using InstructionShape = InstructionShape_;
using WarpShape = WarpShape_;
using ThreadblockShape = ThreadblockShape_;
// accum_t is the accumulator element type (the LinearCombination comment
// below indicates float in practice); lse_scalar_t is the element type of
// the log-sum-exp vector consumed in accumApplyLSEToSmem.
using accum_t = Element_;
using lse_scalar_t = float;
using SmemAccumulatorLayout = cutlass::layout::RowMajor;
// Iterator to load accumulators (results of matmul in registers)
using FragmentIteratorAccumulator =
cutlass::epilogue::warp::FragmentIteratorTensorOp<
WarpShape,
InstructionShape,
accum_t,
typename Operator::Policy::Operator::FragmentC,
cutlass::layout::RowMajor>;
// Iterator to store to shared-memory
// Note the element type is scalar_t, not accum_t: the store performs the
// accum_t -> scalar_t down-conversion (see OutputOpNoOp comment below).
using SmemIteratorD0 = typename cutlass::epilogue::warp::TileIteratorTensorOp<
WarpShape,
InstructionShape,
scalar_t, // accum_t,
SmemAccumulatorLayout>;
// Shared-memory storage sized/laid out to match SmemIteratorD0's element,
// layout and padding; accumToSmem writes into shared_storage.accum_ref().
using AccumulatorSharedStorage =
cutlass::gemm::threadblock::AccumulatorSharedStorage<
ThreadblockShape,
typename SmemIteratorD0::Element,
typename SmemIteratorD0::TensorLayout,
typename SmemIteratorD0::Padding>;
// We need to provide an operation for the epilogue. Let's create an
// operation that does nothing (ScaleType::Nothing), just converts
// from accum_t (float) -> scalar_t (can be half)
using OutputOpNoOp = cutlass::epilogue::thread::LinearCombination<
typename SmemIteratorD0::Element, // ElementOutput
FragmentIteratorAccumulator::Fragment::kElements,
accum_t, // ElementAccumulator
typename SmemIteratorD0::Element, // ElementCompute
cutlass::epilogue::thread::ScaleType::Nothing>;
// Epilogue 1: plain register->smem store with type conversion only.
using Epilogue = cutlass::epilogue::threadblock::EpilogueSmemAccumulator<
SmemIteratorD0,
FragmentIteratorAccumulator,
SmemIteratorD0, // ScaleBiasIterator - not used
OutputOpNoOp>;
// Epilogue 2: with LSE (for backwards pass)
static int const kElementsPerAccess = 2; // TODO: Why 2?
// Predicated vector iterator over the per-row LSE values; the extent passed
// at construction time is padded to kAlignLSE (see accumApplyLSEToSmem).
using IteratorAccumulatorLSE =
cutlass::transform::threadblock::VectorIterator<
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
// Shape
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kN>,
// WarpShape
cutlass::MatrixShape<WarpShape::kM, WarpShape::kN>,
lse_scalar_t,
cutlass::layout::RowMajor,
kElementsPerAccess>>;
// Output op that applies the log-sum-exp correction (presumably computes
// exp(accum - lse) per element — confirm against ApplyLogSumExp's definition)
// while converting to scalar_t. The last argument is an access width chosen
// as 128 bits / sizeof(scalar_t); alternatives considered are kept below.
using EpilogueOpApplyLSE = cutlass::epilogue::thread::ApplyLogSumExp<
scalar_t, // ElementOutput_
lse_scalar_t, // ElementLSE_
accum_t, // ElementAccumulator_
accum_t, // ElementCompute_
128 / cutlass::sizeof_bits<scalar_t>::value
// FragmentIteratorAccumulator::Fragment::kElements
// InstructionShape::kM * InstructionShape::kN / 32
>;
using EpilogueWithLSE =
cutlass::epilogue::threadblock::EpilogueSmemAccumulator<
SmemIteratorD0,
FragmentIteratorAccumulator,
IteratorAccumulatorLSE,
EpilogueOpApplyLSE>;
// Store `accum` (register fragment of the first matmul) into shared memory,
// converting accum_t -> scalar_t via OutputOpNoOp. `tile_coords` selects
// which threadblock tile this warp writes; it is scaled by the iterator's
// TileIterations to obtain the smem tile offset. Device-only.
static void CUTLASS_DEVICE accumToSmem(
AccumulatorSharedStorage& shared_storage,
FragmentC const& accum,
int lane_id,
cutlass::MatrixCoord const& tile_coords) {
SmemIteratorD0 smem_iterator_attn(shared_storage.accum_ref(), lane_id);
// Convert tile coordinates (in threadblock-tile units) into this
// iterator's tile units before offsetting.
smem_iterator_attn.add_tile_offset(
tile_coords *
cutlass::MatrixCoord{
SmemIteratorD0::TileIterations::kRow,
SmemIteratorD0::TileIterations::kColumn});
Epilogue epilogue;
// OutputOpNoOp is default-constructed from empty params: no scaling,
// just the element-type conversion.
epilogue(OutputOpNoOp({}), smem_iterator_attn, accum);
}
// Same as accumToSmem, but applies the LSE correction (EpilogueOpApplyLSE)
// while storing. `lse` points to the per-row log-sum-exp values with
// `lse_extents` valid entries; reads past the extent are presumably
// predicated off by the iterator (confirm in PredicatedVectorAccessIterator).
static void CUTLASS_DEVICE accumApplyLSEToSmem(
AccumulatorSharedStorage& shared_storage,
FragmentC& accum,
lse_scalar_t const* lse,
int32_t lse_extents,
int thread_id,
int warp_id,
int lane_id,
cutlass::MatrixCoord const& tile_coords) {
// The iterator extent is rounded up to a multiple of 32; the rationale is
// not visible here — presumably an alignment requirement of the iterator.
constexpr int32_t kAlignLSE = 32;
IteratorAccumulatorLSE iterator_lse(
lse,
{(int32_t)0, (int32_t)ceil_div(lse_extents, kAlignLSE) * kAlignLSE},
thread_id,
warp_id,
cutlass::MatrixCoord{0, 0} // offset
);
SmemIteratorD0 smem_iterator_attn(shared_storage.accum_ref(), lane_id);
smem_iterator_attn.add_tile_offset(
tile_coords *
cutlass::MatrixCoord{
SmemIteratorD0::TileIterations::kRow,
SmemIteratorD0::TileIterations::kColumn});
EpilogueWithLSE epilogue;
EpilogueOpApplyLSE minus_lse_exp({});
// The same LSE iterator is passed for both the scale slot (documented as
// unused) and the bias slot of the epilogue.
epilogue(
minus_lse_exp,
smem_iterator_attn,
accum,
// scale - unused
iterator_lse,
// bias
iterator_lse);
}
};
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free