MmaMultistageFromSharedMemory Class — PyTorch Architecture
Architecture documentation for the `MmaMultistageFromSharedMemory` class, defined in `mma_from_smem.h` in the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/transformers/cuda/mem_eff_attention/gemm/mma_from_smem.h lines 727–873
// Threadblock-scoped multistage (cp.async-pipelined) matrix multiply that
// reads its A operand from an accumulator tile already resident in shared
// memory, rather than from global memory.
//
// NOTE(review): the `template <...>` header declaring Shape1_, kMaxK_,
// Policy1_, Stages_, WarpIteratorA1_, IteratorB1_, SmemIteratorB1_,
// ScaleOperandA_, ElementC_, LayoutC_ and CacheOpB1 precedes this excerpt
// and is not visible here.
class MmaMultistageFromSharedMemory : public MmaBaseFromSharedMemory<
Shape1_,
kMaxK_,
Policy1_,
Stages_,
typename WarpIteratorA1_::Layout> {
public:
///< Base class (same instantiation as the inheritance list above)
using Base = MmaBaseFromSharedMemory<
Shape1_,
kMaxK_,
Policy1_,
Stages_,
typename WarpIteratorA1_::Layout>;
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape1 = Shape1_;
///< Iterates over tiles of B operand in global memory
using IteratorB1 = IteratorB1_;
// Alias so generic code can refer to the B iterator without the "1" suffix.
using IteratorB = IteratorB1;
///< Policy describing tuning details
using Policy1 = Policy1_;
///< Iterator that writes threadblock tiles of B into shared memory
using SmemIteratorB1 = SmemIteratorB1_;
using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate
///< accumulator tile in shared memory
///< Compile-time switch: when true, fragments of A are scaled elementwise.
static constexpr bool ScaleOperandA = ScaleOperandA_;
///< warp level iterator over A_scale matrix tile kept in shared memory.
///< if elementwise A scaling is disabled then everything this does is no-op.
///< (Selected at compile time: the real WarpIteratorA1 when scaling is on,
///< otherwise a NoOpWarpIteratorScale stub over the same TensorRef type.)
using WarpIteratorAScale = typename cutlass::platform::conditional<
ScaleOperandA,
WarpIteratorA1,
NoOpWarpIteratorScale<typename WarpIteratorA1::TensorRef>>::type;
///< Data type of accumulator matrix
using ElementC = ElementC_;
///< Layout of accumulator matrix
using LayoutC = LayoutC_;
///< Cache operation applied to cp.async loads of operand B
static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1;
///< Forwarded from Base: whether shared memory holds all of B at once.
static constexpr bool kSmemContainsEntireB = Base::kSmemContainsEntireB;
//
// Dependent types
//
/// Fragment of accumulator tile
using FragmentC1 = typename Policy1::Operator::FragmentC;
// Unsuffixed alias for generic callers.
using FragmentC = FragmentC1;
/// Warp-level Mma
using Operator1 = typename Policy1::Operator;
/// Minimum architecture is Sm80 to support cp.async
using ArchTag = arch::Sm80;
/// Complex transform on B operand
static ComplexTransform const kTransformB1 = Operator1::kTransformB;
/// Internal structure exposed for introspection.
struct Detail {
static_assert(
Base::kWarpGemmIterations1 > 1,
"The pipelined structure requires at least two warp-level "
"GEMM operations.");
/// Number of cp.async instructions to load one stage of operand B
static int const TBLoadIterationsB1 =
IteratorB1::ThreadMap::Iterations::kCount;
/// Number of cp.async instructions to load one group of operand B
/// (ceiling division: total loads spread evenly across the warp-level
/// GEMM iterations of a stage).
static int const kAccessesPerGroupB1 =
(TBLoadIterationsB1 + Base::kWarpGemmIterations1 - 1) /
Base::kWarpGemmIterations1;
};
/// Number of pipeline stages that may be loading concurrently: all of them
/// when shared memory holds the entire B operand, otherwise one stage is
/// reserved for the compute side of the pipeline.
static constexpr int kNumStagesConcurrentLoad =
kSmemContainsEntireB ? Base::kStages : Base::kStages - 1;
private:
// Warp-level fragment types for the second GEMM, taken from the warp MMA
// operator selected by Policy1.
using WarpLoadedFragmentA1 = typename Operator1::FragmentA;
/// fragment of OperandA scale matrix. if operand A scaling is disabled this
/// is (almost) empty.
using WarpLoadedFragmentA1Scale = typename WarpIteratorAScale::Fragment;
using WarpLoadedFragmentB1 = typename Operator1::FragmentB;
// Fragments after the operator's load->compute transformation.
using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA;
using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB;
/// applies elementwise scaling to fragment of A. if operand A scaling is
/// disabled this is a no-op.
using FragmentAScaler = FragmentElementwiseScaler<
WarpLoadedFragmentA1,
WarpLoadedFragmentA1Scale,
ScaleOperandA>;
private:
//
// Data members
//
/// Iterator to load a warp-scoped tile of A1 operand from intermediate
/// accumulator tile
WarpIteratorA1 warp_tile_iterator_A1_;
/// Iterator to load a warp-scoped tile of A1_scale operand from shared memory
/// if operand A scaling is disabled everything this does is a no-op.
WarpIteratorAScale warp_tile_iterator_A1_scale_;
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB1 smem_iterator_B1_;
/// Set once the B-operand prologue (initial stage loads) has been issued,
/// so it is not repeated.
bool prologue_done_;
/// constructor for MMA with operand A scaling enabled.
///
/// @param a        ref to the intermediate accumulator tile (operand A) in
///                 shared memory
/// @param a_scale  ref to the A-scale tile in shared memory (ignored by the
///                 no-op iterator when ScaleOperandA is false)
/// @param b_tile   ref to the shared-memory staging buffer for operand B;
///                 also passed to Base and to the smem write iterator
/// @param thread_idx  linear thread index within the threadblock
/// @param warp_idx    linear warp index within the threadblock
/// @param lane_idx    lane index within the warp
CUTLASS_DEVICE
MmaMultistageFromSharedMemory(
typename Base::TensorRefA a,
typename Base::TensorRefA a_scale,
typename Base::TensorRefB b_tile,
int thread_idx,
int warp_idx,
int lane_idx)
: Base(b_tile, thread_idx, warp_idx, lane_idx),
warp_tile_iterator_A1_(a, lane_idx),
warp_tile_iterator_A1_scale_(a_scale, lane_idx),
smem_iterator_B1_(b_tile, thread_idx),
prologue_done_(false) {
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
// Decomposition: warps are laid out M-fastest within an M*N plane, with
// the K coordinate varying slowest.
int warp_idx_mn_1 =
warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN);
int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN);
int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM;
int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM;
// Add per-warp offsets in units of warp-level tiles.
// A (and A-scale) iterators are offset by (m, k*iterations); the B iterator
// is offset by (k*iterations, n) — K split across warps advances by whole
// groups of kWarpGemmIterations1 warp tiles.
warp_tile_iterator_A1_.add_tile_offset(
{warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1});
warp_tile_iterator_A1_scale_.add_tile_offset(
{warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1});
this->warp_tile_iterator_B_.add_tile_offset(
{Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1});
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free